diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index e71149bc586..741418aa23d 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -172,6 +172,10 @@ Improvements
   earlier than regular queries in order to improve cache efficiency. (Adrien
   Grand)
 
+* LUCENE-7707: Use predefined shard index when merging top docs if present. This
+  allows using TopDocs#merge to merge shard responses incrementally instead of
+  only once all shard responses are present. (Simon Willnauer)
+
 Optimizations
 
 * LUCENE-7641: Optimized point range queries to compute documents that do not
diff --git a/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java b/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java
index 69464cfdcdc..eb95e298053 100644
--- a/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java
+++ b/lucene/core/src/java/org/apache/lucene/search/ScoreDoc.java
@@ -28,7 +28,7 @@ public class ScoreDoc {
    * @see IndexSearcher#doc(int) */
   public int doc;
 
-  /** Only set by {@link TopDocs#merge} */
+  /** Only set by {@link TopDocs#merge}*/
   public int shardIndex;
 
   /** Constructs a ScoreDoc. */
diff --git a/lucene/core/src/java/org/apache/lucene/search/TopDocs.java b/lucene/core/src/java/org/apache/lucene/search/TopDocs.java
index c1f825e401b..2913cb2ef6a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TopDocs.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TopDocs.java
@@ -57,22 +57,54 @@ public class TopDocs {
   }
 
   // Refers to one hit:
-  private static class ShardRef {
+  private final static class ShardRef {
     // Which shard (index into shardHits[]):
     final int shardIndex;
+    final boolean useScoreDocIndex;
 
     // Which hit within the shard:
     int hitIndex;
 
-    public ShardRef(int shardIndex) {
+    ShardRef(int shardIndex, boolean useScoreDocIndex) {
       this.shardIndex = shardIndex;
+      this.useScoreDocIndex = useScoreDocIndex;
     }
 
     @Override
    public String toString() {
       return "ShardRef(shardIndex=" + shardIndex + " hitIndex=" + hitIndex + ")";
     }
-  };
+
+    int getShardIndex(ScoreDoc scoreDoc) {
+      if (useScoreDocIndex) {
+        assert scoreDoc.shardIndex != -1 : "scoreDoc shardIndex must be set but wasn't";
+        return scoreDoc.shardIndex;
+      } else {
+        assert scoreDoc.shardIndex == -1 : "scoreDoc shardIndex must be undefined but wasn't";
+        return shardIndex;
+      }
+    }
+  }
+
+  /**
+   * If we need to tie-break since the score / sort values are the same, we first compare the shard index (lower shard wins)
+   * and then, iff the shard index is the same, we use the hit index.
+   */
+  static boolean tieBreakLessThan(ShardRef first, ScoreDoc firstDoc, ShardRef second, ScoreDoc secondDoc) {
+    final int firstShardIndex = first.getShardIndex(firstDoc);
+    final int secondShardIndex = second.getShardIndex(secondDoc);
+    // Tie break: earlier shard wins
+    if (firstShardIndex < secondShardIndex) {
+      return true;
+    } else if (firstShardIndex > secondShardIndex) {
+      return false;
+    } else {
+      // Tie break in same shard: resolve however the
+      // shard had resolved it:
+      assert first.hitIndex != second.hitIndex;
+      return first.hitIndex < second.hitIndex;
+    }
+  }
 
   // Specialized MergeSortQueue that just merges by
   // relevance score, descending:
@@ -91,25 +123,14 @@ public class TopDocs {
     @Override
     public boolean lessThan(ShardRef first, ShardRef second) {
       assert first != second;
-      final float firstScore = shardHits[first.shardIndex][first.hitIndex].score;
-      final float secondScore = shardHits[second.shardIndex][second.hitIndex].score;
-
-      if (firstScore < secondScore) {
+      ScoreDoc firstScoreDoc = shardHits[first.shardIndex][first.hitIndex];
+      ScoreDoc secondScoreDoc = shardHits[second.shardIndex][second.hitIndex];
+      if (firstScoreDoc.score < secondScoreDoc.score) {
         return false;
-      } else if (firstScore > secondScore) {
+      } else if (firstScoreDoc.score > secondScoreDoc.score) {
         return true;
       } else {
-        // Tie break: earlier shard wins
-        if (first.shardIndex < second.shardIndex) {
-          return true;
-        } else if (first.shardIndex > second.shardIndex) {
-          return false;
-        } else {
-          // Tie break in same shard: resolve however the
-          // shard had resolved it:
-          assert first.hitIndex != second.hitIndex;
-          return first.hitIndex < second.hitIndex;
-        }
+        return tieBreakLessThan(first, firstScoreDoc, second, secondScoreDoc);
       }
     }
   }
@@ -172,27 +193,15 @@ public class TopDocs {
           return cmp < 0;
         }
       }
-
-      // Tie break: earlier shard wins
-      if (first.shardIndex < second.shardIndex) {
-        //System.out.println("    return tb true");
-        return true;
-      } else if (first.shardIndex > second.shardIndex) {
-        //System.out.println("    return tb false");
-        return false;
-      } else {
-        // Tie break in same shard: resolve however the
-        // shard had resolved it:
-        //System.out.println("    return tb " + (first.hitIndex < second.hitIndex));
-        assert first.hitIndex != second.hitIndex;
-        return first.hitIndex < second.hitIndex;
-      }
+      return tieBreakLessThan(first, firstFD, second, secondFD);
     }
   }
 
   /** Returns a new TopDocs, containing topN results across
    * the provided TopDocs, sorting by score. Each {@link TopDocs}
    * instance must be sorted.
+   *
+   * @see #merge(int, int, TopDocs[])
    * @lucene.experimental */
   public static TopDocs merge(int topN, TopDocs[] shardHits) {
     return merge(0, topN, shardHits);
@@ -201,6 +210,10 @@
   /**
    * Same as {@link #merge(int, TopDocs[])} but also ignores the top
    * {@code start} top docs. This is typically useful for pagination.
+   *
+   * Note: This method will fill the {@link ScoreDoc#shardIndex} on all returned score docs iff all ScoreDocs passed
+   * to this method have their shard index set to -1. Otherwise the shard index is left untouched. This allows predefining
+   * the shard index in order to incrementally merge shard responses without losing the original shard index.
    * @lucene.experimental */
   public static TopDocs merge(int start, int topN, TopDocs[] shardHits) {
@@ -213,6 +226,7 @@
    * the same Sort, and sort field values must have been
    * filled (ie, fillFields=true must be
    * passed to {@link TopFieldCollector#create}).
+   * @see #merge(Sort, int, int, TopFieldDocs[])
    * @lucene.experimental */
   public static TopFieldDocs merge(Sort sort, int topN, TopFieldDocs[] shardHits) {
     return merge(sort, 0, topN, shardHits);
@@ -221,6 +235,10 @@
   /**
    * Same as {@link #merge(Sort, int, TopFieldDocs[])} but also ignores the top
    * {@code start} top docs. This is typically useful for pagination.
+   *
+   * Note: This method will fill the {@link ScoreDoc#shardIndex} on all returned score docs iff all ScoreDocs passed
+   * to this method have their shard index set to -1. Otherwise the shard index is left untouched. This allows predefining
+   * the shard index in order to incrementally merge shard responses without losing the original shard index.
    * @lucene.experimental */
   public static TopFieldDocs merge(Sort sort, int start, int topN, TopFieldDocs[] shardHits) {
@@ -243,14 +261,26 @@
     int totalHitCount = 0;
     int availHitCount = 0;
     float maxScore = Float.MIN_VALUE;
+    Boolean setShardIndex = null;
     for(int shardIDX=0;shardIDX<shardHits.length;shardIDX++) {
       final TopDocs shard = shardHits[shardIDX];
       totalHitCount += shard.totalHits;
       if (shard.scoreDocs != null && shard.scoreDocs.length > 0) {
+        if (shard.scoreDocs[0].shardIndex == -1) {
+          if (setShardIndex != null && setShardIndex == false) {
+            throw new IllegalStateException("scoreDocs at index " + shardIDX + " has undefined shard indices but previous scoreDocs were predefined");
+          }
+          setShardIndex = true;
+        } else {
+          if (setShardIndex != null && setShardIndex) {
+            throw new IllegalStateException("scoreDocs at index " + shardIDX + " has predefined shard indices but previous scoreDocs were undefined");
+          }
+          setShardIndex = false;
+        }
         availHitCount += shard.scoreDocs.length;
-        queue.add(new ShardRef(shardIDX));
+        queue.add(new ShardRef(shardIDX, setShardIndex == false));
         maxScore = Math.max(maxScore, shard.getMaxScore());
         //System.out.println("  maxScore now " + maxScore + " vs " + shard.getMaxScore());
       }
     }
@@ -272,7 +302,13 @@
       assert queue.size() > 0;
       ShardRef ref = queue.top();
       final ScoreDoc hit = shardHits[ref.shardIndex].scoreDocs[ref.hitIndex++];
-      hit.shardIndex = ref.shardIndex;
+      if (setShardIndex) {
+        // unless this index is already initialized, potentially due to multiple merge phases or explicitly by the user,
+        // we set the shard index to the index of the TopDocs array this hit is coming from.
+        // this allows multiple merge phases if needed but requires extra accounting on the user's end.
+        // at the same time this is fully backwards compatible since the value was initialized to -1 from the beginning
+        hit.shardIndex = ref.shardIndex;
+      }
       if (hitUpto >= start) {
         hits[hitUpto - start] = hit;
       }
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java
index a5eafad343d..37c61a45135 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestTopDocsMerge.java
@@ -30,6 +30,7 @@ import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.ReaderUtil;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.TestUtil;
@@ -37,7 +38,9 @@ import org.apache.lucene.util.TestUtil;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 public class TestTopDocsMerge extends LuceneTestCase {
 
@@ -72,6 +75,64 @@ public class TestTopDocsMerge extends LuceneTestCase {
     testSort(true);
   }
 
+  public void testInconsistentTopDocsFail() {
+    TopDocs[] topDocs = new TopDocs[] {
+        new TopDocs(1, new ScoreDoc[] { new ScoreDoc(1, 1.0f, 1) }),
+        new TopDocs(1, new ScoreDoc[] { new ScoreDoc(1, 1.0f, -1) })
+    };
+    if (random().nextBoolean()) {
+      ArrayUtil.swap(topDocs, 0, 1);
+    }
+    expectThrows(IllegalStateException.class, () -> {
+      TopDocs.merge(0, 1, topDocs);
+    });
+  }
+
+  public void testAssignShardIndex() {
+    boolean useConstantScore = random().nextBoolean();
+    int numTopDocs = 2 + random().nextInt(10);
+    ArrayList<TopDocs> topDocs = new ArrayList<>(numTopDocs);
+    Map<Integer, TopDocs> shardResultMapping = new HashMap<>();
+    int numHitsTotal = 0;
+    for (int i = 0; i < numTopDocs; i++) {
+      int numHits = 1 + random().nextInt(10);
+      numHitsTotal += numHits;
+      ScoreDoc[] scoreDocs = new ScoreDoc[numHits];
+      for (int j = 0; j < scoreDocs.length; j++) {
+        float score = useConstantScore ? 1.0f : random().nextFloat();
+        scoreDocs[j] = new ScoreDoc((100 * i) + j, score, i);
+        // we set the shard index to the index in the list here but shuffle the entire list below
+      }
+      topDocs.add(new TopDocs(numHits, scoreDocs));
+      shardResultMapping.put(i, topDocs.get(i));
+    }
+    // shuffle the entire thing such that we don't get a 1 to 1 mapping of shard index to index in the array
+    // -- well likely ;)
+    Collections.shuffle(topDocs, random());
+    final int from = random().nextInt(numHitsTotal - 1);
+    final int size = 1 + random().nextInt(numHitsTotal - from);
+    TopDocs merge = TopDocs.merge(from, size, topDocs.toArray(new TopDocs[0]));
+    assertTrue(merge.scoreDocs.length > 0);
+    for (ScoreDoc scoreDoc : merge.scoreDocs) {
+      assertTrue(scoreDoc.shardIndex != -1);
+      TopDocs shardTopDocs = shardResultMapping.get(scoreDoc.shardIndex);
+      assertNotNull(shardTopDocs);
+      boolean found = false;
+      for (ScoreDoc shardScoreDoc : shardTopDocs.scoreDocs) {
+        if (shardScoreDoc == scoreDoc) {
+          found = true;
+          break;
+        }
+      }
+      assertTrue(found);
+    }
+
+    // now ensure merge is stable even if we use our own shard IDs
+    Collections.shuffle(topDocs, random());
+    TopDocs merge2 = TopDocs.merge(from, size, topDocs.toArray(new TopDocs[0]));
+    assertArrayEquals(merge.scoreDocs, merge2.scoreDocs);
+  }
+
   void testSort(boolean useFrom) throws Exception {
     IndexReader reader = null;
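Reviewer note (not part of the patch): below is a minimal sketch of the incremental merge workflow that the new javadoc on TopDocs#merge describes, assuming the caller builds per-shard TopDocs by hand and pre-assigns a global shard index through the three-argument ScoreDoc constructor. The class name and the hard-coded hits are made up for illustration only.

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;

public class IncrementalTopDocsMergeSketch {
  public static void main(String[] args) {
    // Per-shard results with predefined shard indices (third ScoreDoc argument).
    // In a real system these would come from searching each shard.
    TopDocs shard0 = new TopDocs(2, new ScoreDoc[] {
        new ScoreDoc(4, 3.0f, 0), new ScoreDoc(9, 1.0f, 0) });
    TopDocs shard1 = new TopDocs(1, new ScoreDoc[] { new ScoreDoc(7, 2.5f, 1) });
    TopDocs shard2 = new TopDocs(1, new ScoreDoc[] { new ScoreDoc(2, 2.5f, 2) });

    // First-level merge as soon as the first two shards have responded.
    TopDocs partial = TopDocs.merge(3, new TopDocs[] { shard0, shard1 });

    // Second-level merge once the last shard arrives. Because every ScoreDoc already
    // carries a shardIndex != -1, merge() preserves it instead of overwriting it with
    // the position in the shardHits array, and ties still break on the original shard index.
    TopDocs result = TopDocs.merge(3, new TopDocs[] { partial, shard2 });

    for (ScoreDoc sd : result.scoreDocs) {
      System.out.println("doc=" + sd.doc + " score=" + sd.score + " shard=" + sd.shardIndex);
    }
  }
}

Mixing the two modes in a single merge call (some ScoreDocs left at -1, others predefined) now throws an IllegalStateException, which is what testInconsistentTopDocsFail exercises.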