diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index e71149bc586..741418aa23d 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -172,6 +172,10 @@ Improvements
earlier than regular queries in order to improve cache efficiency.
(Adrien Grand)
+* LUCENE-7707: Use the predefined shard index when merging top docs if present. This
+  allows TopDocs#merge to be used to merge shard responses incrementally instead of
+  waiting until all shard responses are present. (Simon Willnauer)
+
Optimizations
* LUCENE-7641: Optimized point range queries to compute documents that do not
diff --git a/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java b/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java
index 69464cfdcdc..eb95e298053 100644
--- a/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java
+++ b/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java
@@ -28,7 +28,7 @@ public class ScoreDoc {
* @see IndexSearcher#doc(int) */
public int doc;
- /** Only set by {@link TopDocs#merge} */
+  /** Only set by {@link TopDocs#merge}. Set to -1 by default, which indicates that it has not been set yet. */
public int shardIndex;
/** Constructs a ScoreDoc. */
diff --git a/lucene/core/src/java/org/apache/lucene/search/TopDocs.java b/lucene/core/src/java/org/apache/lucene/search/TopDocs.java
index c1f825e401b..2913cb2ef6a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TopDocs.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TopDocs.java
@@ -57,22 +57,54 @@ public class TopDocs {
}
// Refers to one hit:
- private static class ShardRef {
+  private static final class ShardRef {
// Which shard (index into shardHits[]):
final int shardIndex;
+ final boolean useScoreDocIndex;
// Which hit within the shard:
int hitIndex;
- public ShardRef(int shardIndex) {
+ ShardRef(int shardIndex, boolean useScoreDocIndex) {
this.shardIndex = shardIndex;
+ this.useScoreDocIndex = useScoreDocIndex;
}
@Override
public String toString() {
return "ShardRef(shardIndex=" + shardIndex + " hitIndex=" + hitIndex + ")";
}
- };
+
+ int getShardIndex(ScoreDoc scoreDoc) {
+ if (useScoreDocIndex) {
+        assert scoreDoc.shardIndex != -1 : "scoreDoc shardIndex must be set but wasn't";
+ return scoreDoc.shardIndex;
+ } else {
+ assert scoreDoc.shardIndex == -1 : "scoreDoc shardIndex must be undefined but wasn't";
+ return shardIndex;
+ }
+ }
+ }
+
+ /**
+   * If we need to tie-break because the score / sort value are the same, we first compare the shard index
+   * (the lower shard index wins) and only if the shard index is also the same do we use the hit index.
+ */
+ static boolean tieBreakLessThan(ShardRef first, ScoreDoc firstDoc, ShardRef second, ScoreDoc secondDoc) {
+ final int firstShardIndex = first.getShardIndex(firstDoc);
+ final int secondShardIndex = second.getShardIndex(secondDoc);
+ // Tie break: earlier shard wins
+    if (firstShardIndex < secondShardIndex) {
+ return true;
+ } else if (firstShardIndex > secondShardIndex) {
+ return false;
+ } else {
+ // Tie break in same shard: resolve however the
+ // shard had resolved it:
+ assert first.hitIndex != second.hitIndex;
+ return first.hitIndex < second.hitIndex;
+ }
+ }
// Specialized MergeSortQueue that just merges by
// relevance score, descending:
@@ -91,25 +123,14 @@ public class TopDocs {
@Override
public boolean lessThan(ShardRef first, ShardRef second) {
assert first != second;
- final float firstScore = shardHits[first.shardIndex][first.hitIndex].score;
- final float secondScore = shardHits[second.shardIndex][second.hitIndex].score;
-
- if (firstScore < secondScore) {
+ ScoreDoc firstScoreDoc = shardHits[first.shardIndex][first.hitIndex];
+ ScoreDoc secondScoreDoc = shardHits[second.shardIndex][second.hitIndex];
+ if (firstScoreDoc.score < secondScoreDoc.score) {
return false;
- } else if (firstScore > secondScore) {
+ } else if (firstScoreDoc.score > secondScoreDoc.score) {
return true;
} else {
- // Tie break: earlier shard wins
- if (first.shardIndex < second.shardIndex) {
- return true;
- } else if (first.shardIndex > second.shardIndex) {
- return false;
- } else {
- // Tie break in same shard: resolve however the
- // shard had resolved it:
- assert first.hitIndex != second.hitIndex;
- return first.hitIndex < second.hitIndex;
- }
+ return tieBreakLessThan(first, firstScoreDoc, second, secondScoreDoc);
}
}
}
@@ -172,27 +193,15 @@ public class TopDocs {
return cmp < 0;
}
}
-
- // Tie break: earlier shard wins
- if (first.shardIndex < second.shardIndex) {
- //System.out.println(" return tb true");
- return true;
- } else if (first.shardIndex > second.shardIndex) {
- //System.out.println(" return tb false");
- return false;
- } else {
- // Tie break in same shard: resolve however the
- // shard had resolved it:
- //System.out.println(" return tb " + (first.hitIndex < second.hitIndex));
- assert first.hitIndex != second.hitIndex;
- return first.hitIndex < second.hitIndex;
- }
+ return tieBreakLessThan(first, firstFD, second, secondFD);
}
}
/** Returns a new TopDocs, containing topN results across
* the provided TopDocs, sorting by score. Each {@link TopDocs}
* instance must be sorted.
+ *
+ * @see #merge(int, int, TopDocs[])
* @lucene.experimental */
public static TopDocs merge(int topN, TopDocs[] shardHits) {
return merge(0, topN, shardHits);
@@ -201,6 +210,10 @@ public class TopDocs {
/**
* Same as {@link #merge(int, TopDocs[])} but also ignores the top
* {@code start} top docs. This is typically useful for pagination.
+ *
+   * Note: This method will fill the {@link ScoreDoc#shardIndex} on all returned score docs iff all ScoreDocs
+   * passed to it have their shard index set to -1. Otherwise the shard index is left untouched. This makes it
+   * possible to predefine the shard index in order to incrementally merge shard responses without losing the
+   * original shard index.
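+   *
+   * <p>For example, a caller that has already assigned shard indices can merge shard responses in several
+   * rounds. The sketch below assumes {@code firstBatch}, {@code secondBatch} and {@code topN} are supplied by
+   * the caller, with every {@link ScoreDoc} carrying a caller-provided shard index:
+   * <pre>
+   *   TopDocs first  = TopDocs.merge(0, topN, firstBatch);   // predefined shard indices are preserved
+   *   TopDocs second = TopDocs.merge(0, topN, secondBatch);  // predefined shard indices are preserved
+   *   // merging the already merged results still refers to the original shard indices
+   *   TopDocs total  = TopDocs.merge(0, topN, new TopDocs[] { first, second });
+   * </pre>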
* @lucene.experimental
*/
public static TopDocs merge(int start, int topN, TopDocs[] shardHits) {
@@ -213,6 +226,7 @@ public class TopDocs {
* the same Sort, and sort field values must have been
   * filled (ie, <code>fillFields=true</code> must be
* passed to {@link TopFieldCollector#create}).
+ * @see #merge(Sort, int, int, TopFieldDocs[])
* @lucene.experimental */
public static TopFieldDocs merge(Sort sort, int topN, TopFieldDocs[] shardHits) {
return merge(sort, 0, topN, shardHits);
@@ -221,6 +235,10 @@ public class TopDocs {
/**
* Same as {@link #merge(Sort, int, TopFieldDocs[])} but also ignores the top
* {@code start} top docs. This is typically useful for pagination.
+ *
+   * Note: This method will fill the {@link ScoreDoc#shardIndex} on all returned score docs iff all ScoreDocs
+   * passed to it have their shard index set to -1. Otherwise the shard index is left untouched. This makes it
+   * possible to predefine the shard index in order to incrementally merge shard responses without losing the
+   * original shard index.
* @lucene.experimental
*/
public static TopFieldDocs merge(Sort sort, int start, int topN, TopFieldDocs[] shardHits) {
@@ -243,14 +261,26 @@ public class TopDocs {
int totalHitCount = 0;
int availHitCount = 0;
float maxScore = Float.MIN_VALUE;
+ Boolean setShardIndex = null;
    for(int shardIDX=0;shardIDX<shardHits.length;shardIDX++) {
      final TopDocs shard = shardHits[shardIDX];
      totalHitCount += shard.totalHits;
      if (shard.scoreDocs != null && shard.scoreDocs.length > 0) {
+ if (shard.scoreDocs[0].shardIndex == -1) {
+ if (setShardIndex != null && setShardIndex == false) {
+ throw new IllegalStateException("scoreDocs at index " + shardIDX + " has undefined shard indices but previous scoreDocs were predefined");
+ }
+ setShardIndex = true;
+ } else {
+ if (setShardIndex != null && setShardIndex) {
+ throw new IllegalStateException("scoreDocs at index " + shardIDX + " has predefined shard indices but previous scoreDocs were undefined");
+ }
+ setShardIndex = false;
+ }
availHitCount += shard.scoreDocs.length;
- queue.add(new ShardRef(shardIDX));
+ queue.add(new ShardRef(shardIDX, setShardIndex == false));
maxScore = Math.max(maxScore, shard.getMaxScore());
//System.out.println(" maxScore now " + maxScore + " vs " + shard.getMaxScore());
}
@@ -272,7 +302,13 @@ public class TopDocs {
assert queue.size() > 0;
ShardRef ref = queue.top();
final ScoreDoc hit = shardHits[ref.shardIndex].scoreDocs[ref.hitIndex++];
- hit.shardIndex = ref.shardIndex;
+ if (setShardIndex) {
+        // Unless this index is already initialized, potentially due to multiple merge phases or explicitly by
+        // the user, we set the shard index to the index of the TopDocs array this hit is coming from.
+        // This allows multiple merge phases if needed but requires extra accounting on the user's end.
+        // At the same time this is fully backwards compatible since the value was initialized to -1 from the beginning.
+ hit.shardIndex = ref.shardIndex;
+ }
if (hitUpto >= start) {
hits[hitUpto - start] = hit;
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java
index a5eafad343d..37c61a45135 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java
@@ -30,6 +30,7 @@ import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
@@ -37,7 +38,9 @@ import org.apache.lucene.util.TestUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
public class TestTopDocsMerge extends LuceneTestCase {
@@ -72,6 +75,64 @@ public class TestTopDocsMerge extends LuceneTestCase {
testSort(true);
}
+ public void testInconsistentTopDocsFail() {
+ TopDocs[] topDocs = new TopDocs[] {
+ new TopDocs(1, new ScoreDoc[] { new ScoreDoc(1, 1.0f, 1) }),
+ new TopDocs(1, new ScoreDoc[] { new ScoreDoc(1, 1.0f, -1) })
+ };
+ if (random().nextBoolean()) {
+ ArrayUtil.swap(topDocs, 0, 1);
+ }
+ expectThrows(IllegalStateException.class, () -> {
+ TopDocs.merge(0, 1, topDocs);
+ });
+ }
+
+ public void testAssignShardIndex() {
+ boolean useConstantScore = random().nextBoolean();
+ int numTopDocs = 2 + random().nextInt(10);
+    ArrayList<TopDocs> topDocs = new ArrayList<>(numTopDocs);
+    Map<Integer, TopDocs> shardResultMapping = new HashMap<>();
+ int numHitsTotal = 0;
+ for (int i = 0; i < numTopDocs; i++) {
+ int numHits = 1 + random().nextInt(10);
+ numHitsTotal += numHits;
+ ScoreDoc[] scoreDocs = new ScoreDoc[numHits];
+ for (int j = 0; j < scoreDocs.length; j++) {
+ float score = useConstantScore ? 1.0f : random().nextFloat();
+        scoreDocs[j] = new ScoreDoc((100 * i) + j, score, i);
+ // we set the shard index to index in the list here but shuffle the entire list below
+ }
+ topDocs.add(new TopDocs(numHits, scoreDocs));
+ shardResultMapping.put(i, topDocs.get(i));
+ }
+    // shuffle the entire list so that the shard index (most likely) doesn't map 1:1 to the position in the array
+ Collections.shuffle(topDocs, random());
+ final int from = random().nextInt(numHitsTotal-1);
+ final int size = 1 + random().nextInt(numHitsTotal - from);
+ TopDocs merge = TopDocs.merge(from, size, topDocs.toArray(new TopDocs[0]));
+ assertTrue(merge.scoreDocs.length > 0);
+ for (ScoreDoc scoreDoc : merge.scoreDocs) {
+ assertTrue(scoreDoc.shardIndex != -1);
+ TopDocs shardTopDocs = shardResultMapping.get(scoreDoc.shardIndex);
+ assertNotNull(shardTopDocs);
+ boolean found = false;
+ for (ScoreDoc shardScoreDoc : shardTopDocs.scoreDocs) {
+ if (shardScoreDoc == scoreDoc) {
+ found = true;
+ break;
+ }
+ }
+ assertTrue(found);
+ }
+
+ // now ensure merge is stable even if we use our own shard IDs
+ Collections.shuffle(topDocs, random());
+ TopDocs merge2 = TopDocs.merge(from, size, topDocs.toArray(new TopDocs[0]));
+ assertArrayEquals(merge.scoreDocs, merge2.scoreDocs);
+ }
+
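+  // Sketch of the incremental merge use case enabled by predefined shard indices: indices assigned up front
+  // survive a first merge round and a subsequent merge of the already merged results. The method name and the
+  // concrete docs/scores below are illustrative only.
+  public void testIncrementalMergeKeepsPredefinedShardIndex() {
+    TopDocs[] firstBatch = new TopDocs[] {
+        new TopDocs(1, new ScoreDoc[] { new ScoreDoc(0, 3.0f, 0) }),
+        new TopDocs(1, new ScoreDoc[] { new ScoreDoc(0, 2.0f, 1) })
+    };
+    TopDocs[] secondBatch = new TopDocs[] {
+        new TopDocs(1, new ScoreDoc[] { new ScoreDoc(0, 1.0f, 2) })
+    };
+    // merge each batch on its own first, then merge the already merged results
+    TopDocs merged = TopDocs.merge(0, 3, new TopDocs[] {
+        TopDocs.merge(0, 2, firstBatch), TopDocs.merge(0, 1, secondBatch) });
+    assertEquals(3, merged.scoreDocs.length);
+    for (int i = 0; i < merged.scoreDocs.length; i++) {
+      // scores were chosen descending per shard index, so the i-th hit must come from shard i
+      assertEquals(i, merged.scoreDocs[i].shardIndex);
+    }
+  }
+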
void testSort(boolean useFrom) throws Exception {
IndexReader reader = null;