diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index f4e6e3e492e..70eeb2a811e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -32,11 +32,9 @@ import org.elasticsearch.search.sort.MinAndMax; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.transport.Transport; -import java.util.Arrays; import java.util.Comparator; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; @@ -128,7 +126,18 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction[] minAndMaxes) { - return Arrays.stream(minAndMaxes).anyMatch(Objects::nonNull); + Class clazz = null; + for (MinAndMax minAndMax : minAndMaxes) { + if (clazz == null) { + clazz = minAndMax == null ? null : minAndMax.getMin().getClass(); + } else if (minAndMax != null && clazz != minAndMax.getMin().getClass()) { + // we don't support sort values that mix different types (e.g.: long/double, numeric/keyword). + // TODO: we could fail the request because there is a high probability + // that the merging of topdocs will fail later for the same reason ? + return false; + } + } + return clazz != null; } private static Comparator shardComparator(GroupShardsIterator shardsIts, diff --git a/server/src/main/java/org/elasticsearch/search/sort/MinAndMax.java b/server/src/main/java/org/elasticsearch/search/sort/MinAndMax.java index ab02d6df799..703ae6939aa 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/MinAndMax.java +++ b/server/src/main/java/org/elasticsearch/search/sort/MinAndMax.java @@ -55,14 +55,14 @@ public class MinAndMax> implements Writeable { /** * Return the minimum value. */ - T getMin() { + public T getMin() { return minValue; } /** * Return the maximum value. */ - T getMax() { + public T getMax() { return maxValue; } diff --git a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java index b434a98dd40..f79246be426 100644 --- a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.action.search; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; @@ -54,6 +55,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; +import static org.hamcrest.Matchers.equalTo; + public class CanMatchPreFilterSearchPhaseTests extends ESTestCase { public void testFilterShards() throws InterruptedException { @@ -350,4 +353,76 @@ public class CanMatchPreFilterSearchPhaseTests extends ESTestCase { } } } + + public void testInvalidSortShards() throws InterruptedException { + final TransportSearchAction.SearchTimeProvider timeProvider = + new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(), System::nanoTime); + + Map lookup = new ConcurrentHashMap<>(); + DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT); + DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT); + lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode)); + lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode)); + + for (SortOrder order : SortOrder.values()) { + int numShards = randomIntBetween(2, 20); + List shardIds = new ArrayList<>(); + Set shardToSkip = new HashSet<>(); + + SearchTransportService searchTransportService = new SearchTransportService(null, null) { + @Override + public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task, + ActionListener listener) { + final MinAndMax minMax; + if (request.shardId().id() == numShards-1) { + minMax = new MinAndMax<>(new BytesRef("bar"), new BytesRef("baz")); + } else { + Long min = randomLong(); + Long max = randomLongBetween(min, Long.MAX_VALUE); + minMax = new MinAndMax<>(min, max); + } + boolean canMatch = frequently(); + synchronized (shardIds) { + shardIds.add(request.shardId()); + if (canMatch == false) { + shardToSkip.add(request.shardId()); + } + } + new Thread(() -> listener.onResponse(new SearchService.CanMatchResponse(canMatch, minMax))).start(); + } + }; + + AtomicReference> result = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + GroupShardsIterator shardsIter = SearchAsyncActionTests.getShardsIter("logs", + new OriginalIndices(new String[]{"logs"}, SearchRequest.DEFAULT_INDICES_OPTIONS), + numShards, randomBoolean(), primaryNode, replicaNode); + final SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().sort(SortBuilders.fieldSort("timestamp").order(order))); + searchRequest.allowPartialSearchResults(true); + + CanMatchPreFilterSearchPhase canMatchPhase = new CanMatchPreFilterSearchPhase(logger, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), + Collections.emptyMap(), Collections.emptyMap(), EsExecutors.newDirectExecutorService(), + searchRequest, null, shardsIter, timeProvider, ClusterState.EMPTY_STATE, null, + (iter) -> new SearchPhase("test") { + @Override + public void run() { + result.set(iter); + latch.countDown(); + } + }, SearchResponse.Clusters.EMPTY); + + canMatchPhase.start(); + latch.await(); + int shardId = 0; + for (SearchShardIterator i : result.get()) { + assertThat(i.shardId().id(), equalTo(shardId++)); + assertEquals(shardToSkip.contains(i.shardId()), i.skip()); + } + assertThat(result.get().size(), equalTo(numShards)); + } + } }