From 3111aeab0e3453774690d70a3b1d7c24b5647080 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Tue, 12 May 2015 15:57:08 +0000 Subject: [PATCH] LUCENE-6472: Added min and max document options to global ordinal join git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1678989 13f79535-47bb-0310-9956-ffa450edef68 --- lucene/CHANGES.txt | 6 +- .../search/join/BaseGlobalOrdinalScorer.java | 4 +- .../search/join/GlobalOrdinalsQuery.java | 9 +- .../GlobalOrdinalsWithScoreCollector.java | 126 +++++++++++++++--- .../join/GlobalOrdinalsWithScoreQuery.java | 43 +++--- .../apache/lucene/search/join/JoinUtil.java | 81 ++++++++--- .../lucene/search/join/TestJoinUtil.java | 73 +++++++++- 7 files changed, 278 insertions(+), 64 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 0f373ba4235..81a5d60386b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -54,9 +54,9 @@ New Features queries, and supports two-phased iterators to avoid loading positions when possible. (Paul Elschot via Robert Muir) -* LUCENE-6352: Added a new query time join to the join module that uses - global ordinals, which is faster for subsequent joins between reopens. - (Martijn van Groningen, Adrien Grand) +* LUCENE-6352, LUCENE-6472: Added a new query time join to the join module + that uses global ordinals, which is faster for subsequent joins between + reopens. (Martijn van Groningen, Adrien Grand) * LUCENE-5879: Added experimental auto-prefix terms to BlockTree terms dictionary, exposed as AutoPrefixPostingsFormat (Adrien Grand, diff --git a/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java b/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java index 4d81d58f92a..e04e275eca9 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java @@ -28,15 +28,13 @@ import java.io.IOException; abstract class BaseGlobalOrdinalScorer extends Scorer { - final LongBitSet foundOrds; final SortedDocValues values; final Scorer approximationScorer; float score; - public BaseGlobalOrdinalScorer(Weight weight, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer) { + public BaseGlobalOrdinalScorer(Weight weight, SortedDocValues values, Scorer approximationScorer) { super(weight); - this.foundOrds = foundOrds; this.values = values; this.approximationScorer = approximationScorer; } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java index f292f53e9b4..19b908f7902 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java @@ -160,11 +160,13 @@ final class GlobalOrdinalsQuery extends Query { final static class OrdinalMapScorer extends BaseGlobalOrdinalScorer { + final LongBitSet foundOrds; final LongValues segmentOrdToGlobalOrdLookup; public OrdinalMapScorer(Weight weight, float score, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer, LongValues segmentOrdToGlobalOrdLookup) { - super(weight, foundOrds, values, approximationScorer); + super(weight, values, approximationScorer); this.score = score; + this.foundOrds = foundOrds; this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; } @@ -203,9 +205,12 @@ final class GlobalOrdinalsQuery extends Query { final static class SegmentOrdinalScorer extends BaseGlobalOrdinalScorer { + final LongBitSet foundOrds; + public SegmentOrdinalScorer(Weight weight, float score, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer) { - super(weight, foundOrds, values, approximationScorer); + super(weight, values, approximationScorer); this.score = score; + this.foundOrds = foundOrds; } @Override diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java index 9252b56b12c..b02d6e59dfd 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java @@ -33,23 +33,48 @@ import java.util.Arrays; abstract class GlobalOrdinalsWithScoreCollector implements Collector { final String field; + final boolean doMinMax; + final int min; + final int max; final MultiDocValues.OrdinalMap ordinalMap; final LongBitSet collectedOrds; - protected final Scores scores; - GlobalOrdinalsWithScoreCollector(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { + protected final Scores scores; + protected final Occurrences occurrences; + + GlobalOrdinalsWithScoreCollector(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount, ScoreMode scoreMode, int min, int max) { if (valueCount > Integer.MAX_VALUE) { // We simply don't support more than throw new IllegalStateException("Can't collect more than [" + Integer.MAX_VALUE + "] ids"); } this.field = field; + this.doMinMax = !(min <= 0 && max == Integer.MAX_VALUE); + this.min = min; + this.max = max;; this.ordinalMap = ordinalMap; this.collectedOrds = new LongBitSet(valueCount); - this.scores = new Scores(valueCount, unset()); + if (scoreMode != ScoreMode.None) { + this.scores = new Scores(valueCount, unset()); + } else { + this.scores = null; + } + if (scoreMode == ScoreMode.Avg || doMinMax) { + this.occurrences = new Occurrences(valueCount); + } else { + this.occurrences = null; + } } - public LongBitSet getCollectorOrdinals() { - return collectedOrds; + public boolean match(int globalOrd) { + if (collectedOrds.get(globalOrd)) { + if (doMinMax) { + final int occurrence = occurrences.getOccurrence(globalOrd); + return occurrence >= min && occurrence <= max; + } else { + return true; + } + } + return false; } public float score(int globalOrdinal) { @@ -96,6 +121,9 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { float existingScore = scores.getScore(globalOrd); float newScore = scorer.score(); doScore(globalOrd, existingScore, newScore); + if (occurrences != null) { + occurrences.increment(globalOrd); + } } } @@ -122,6 +150,9 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { float existingScore = scores.getScore(segmentOrd); float newScore = scorer.score(); doScore(segmentOrd, existingScore, newScore); + if (occurrences != null) { + occurrences.increment(segmentOrd); + } } } @@ -133,8 +164,8 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { static final class Min extends GlobalOrdinalsWithScoreCollector { - public Min(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { - super(field, ordinalMap, valueCount); + public Min(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount, int min, int max) { + super(field, ordinalMap, valueCount, ScoreMode.Min, min, max); } @Override @@ -150,8 +181,8 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { static final class Max extends GlobalOrdinalsWithScoreCollector { - public Max(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { - super(field, ordinalMap, valueCount); + public Max(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount, int min, int max) { + super(field, ordinalMap, valueCount, ScoreMode.Max, min, max); } @Override @@ -167,8 +198,8 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { static final class Sum extends GlobalOrdinalsWithScoreCollector { - public Sum(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { - super(field, ordinalMap, valueCount); + public Sum(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount, int min, int max) { + super(field, ordinalMap, valueCount, ScoreMode.Total, min, max); } @Override @@ -184,16 +215,12 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { static final class Avg extends GlobalOrdinalsWithScoreCollector { - private final Occurrences occurrences; - - public Avg(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { - super(field, ordinalMap, valueCount); - this.occurrences = new Occurrences(valueCount); + public Avg(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount, int min, int max) { + super(field, ordinalMap, valueCount, ScoreMode.Avg, min, max); } @Override protected void doScore(int globalOrd, float existingScore, float newScore) { - occurrences.increment(globalOrd); scores.setScore(globalOrd, existingScore + newScore); } @@ -208,6 +235,71 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector { } } + static final class NoScore extends GlobalOrdinalsWithScoreCollector { + + public NoScore(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount, int min, int max) { + super(field, ordinalMap, valueCount, ScoreMode.None, min, max); + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + SortedDocValues docTermOrds = DocValues.getSorted(context.reader(), field); + if (ordinalMap != null) { + LongValues segmentOrdToGlobalOrdLookup = ordinalMap.getGlobalOrds(context.ord); + return new LeafCollector() { + + @Override + public void setScorer(Scorer scorer) throws IOException { + } + + @Override + public void collect(int doc) throws IOException { + final long segmentOrd = docTermOrds.getOrd(doc); + if (segmentOrd != -1) { + final int globalOrd = (int) segmentOrdToGlobalOrdLookup.get(segmentOrd); + collectedOrds.set(globalOrd); + occurrences.increment(globalOrd); + } + } + }; + } else { + return new LeafCollector() { + @Override + public void setScorer(Scorer scorer) throws IOException { + } + + @Override + public void collect(int doc) throws IOException { + final int segmentOrd = docTermOrds.getOrd(doc); + if (segmentOrd != -1) { + collectedOrds.set(segmentOrd); + occurrences.increment(segmentOrd); + } + } + }; + } + } + + @Override + protected void doScore(int globalOrd, float existingScore, float newScore) { + } + + @Override + public float score(int globalOrdinal) { + return 1f; + } + + @Override + protected float unset() { + return 0f; + } + + @Override + public boolean needsScores() { + return false; + } + } + // Because the global ordinal is directly used as a key to a score we should be somewhat smart about allocation // the scores array. Most of the times not all docs match so splitting the scores array up in blocks can prevent creation of huge arrays. // Also working with smaller arrays is supposed to be more gc friendly diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java index 093475b8f53..c1bcb646063 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java @@ -17,9 +17,6 @@ package org.apache.lucene.search.join; * limitations under the License. */ -import java.io.IOException; -import java.util.Set; - import org.apache.lucene.index.DocValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; @@ -37,6 +34,9 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LongValues; +import java.io.IOException; +import java.util.Set; + final class GlobalOrdinalsWithScoreQuery extends Query { private final GlobalOrdinalsWithScoreCollector collector; @@ -47,14 +47,18 @@ final class GlobalOrdinalsWithScoreQuery extends Query { // just for hashcode and equals: private final Query fromQuery; + private final int min; + private final int max; private final IndexReader indexReader; - GlobalOrdinalsWithScoreQuery(GlobalOrdinalsWithScoreCollector collector, String joinField, MultiDocValues.OrdinalMap globalOrds, Query toQuery, Query fromQuery, IndexReader indexReader) { + GlobalOrdinalsWithScoreQuery(GlobalOrdinalsWithScoreCollector collector, String joinField, MultiDocValues.OrdinalMap globalOrds, Query toQuery, Query fromQuery, int min, int max, IndexReader indexReader) { this.collector = collector; this.joinField = joinField; this.globalOrds = globalOrds; this.toQuery = toQuery; this.fromQuery = fromQuery; + this.min = min; + this.max = max; this.indexReader = indexReader; } @@ -71,8 +75,10 @@ final class GlobalOrdinalsWithScoreQuery extends Query { GlobalOrdinalsWithScoreQuery that = (GlobalOrdinalsWithScoreQuery) o; - if (!fromQuery.equals(that.fromQuery)) return false; + if (min != that.min) return false; + if (max != that.max) return false; if (!joinField.equals(that.joinField)) return false; + if (!fromQuery.equals(that.fromQuery)) return false; if (!toQuery.equals(that.toQuery)) return false; if (!indexReader.equals(that.indexReader)) return false; @@ -85,6 +91,8 @@ final class GlobalOrdinalsWithScoreQuery extends Query { result = 31 * result + joinField.hashCode(); result = 31 * result + toQuery.hashCode(); result = 31 * result + fromQuery.hashCode(); + result = 31 * result + min; + result = 31 * result + max; result = 31 * result + indexReader.hashCode(); return result; } @@ -92,7 +100,10 @@ final class GlobalOrdinalsWithScoreQuery extends Query { @Override public String toString(String field) { return "GlobalOrdinalsQuery{" + - "joinField=" + joinField + + "joinField=" + joinField + + "min=" + min + + "max=" + max + + "fromQuery=" + fromQuery + '}'; } @@ -168,7 +179,7 @@ final class GlobalOrdinalsWithScoreQuery extends Query { final GlobalOrdinalsWithScoreCollector collector; public OrdinalMapScorer(Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, Scorer approximationScorer, LongValues segmentOrdToGlobalOrdLookup) { - super(weight, collector.getCollectorOrdinals(), values, approximationScorer); + super(weight, values, approximationScorer); this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; this.collector = collector; } @@ -178,9 +189,9 @@ final class GlobalOrdinalsWithScoreQuery extends Query { for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) { final long segmentOrd = values.getOrd(docID); if (segmentOrd != -1) { - final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); - if (foundOrds.get(globalOrd)) { - score = collector.score((int) globalOrd); + final int globalOrd = (int) segmentOrdToGlobalOrdLookup.get(segmentOrd); + if (collector.match(globalOrd)) { + score = collector.score(globalOrd); return docID; } } @@ -196,9 +207,9 @@ final class GlobalOrdinalsWithScoreQuery extends Query { public boolean matches() throws IOException { final long segmentOrd = values.getOrd(approximationScorer.docID()); if (segmentOrd != -1) { - final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); - if (foundOrds.get(globalOrd)) { - score = collector.score((int) globalOrd); + final int globalOrd = (int) segmentOrdToGlobalOrdLookup.get(segmentOrd); + if (collector.match(globalOrd)) { + score = collector.score(globalOrd); return true; } } @@ -214,7 +225,7 @@ final class GlobalOrdinalsWithScoreQuery extends Query { final GlobalOrdinalsWithScoreCollector collector; public SegmentOrdinalScorer(Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, Scorer approximationScorer) { - super(weight, collector.getCollectorOrdinals(), values, approximationScorer); + super(weight, values, approximationScorer); this.collector = collector; } @@ -223,7 +234,7 @@ final class GlobalOrdinalsWithScoreQuery extends Query { for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) { final int segmentOrd = values.getOrd(docID); if (segmentOrd != -1) { - if (foundOrds.get(segmentOrd)) { + if (collector.match(segmentOrd)) { score = collector.score(segmentOrd); return docID; } @@ -240,7 +251,7 @@ final class GlobalOrdinalsWithScoreQuery extends Query { public boolean matches() throws IOException { final int segmentOrd = values.getOrd(approximationScorer.docID()); if (segmentOrd != -1) { - if (foundOrds.get(segmentOrd)) { + if (collector.match(segmentOrd)) { score = collector.score(segmentOrd); return true; } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java index 5c1dc65c3f1..5ab2430122d 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java @@ -29,7 +29,7 @@ import java.io.IOException; import java.util.Locale; /** - * Utility for query time joining using TermsQuery and TermsCollector. + * Utility for query time joining. * * @lucene.experimental */ @@ -97,17 +97,10 @@ public final class JoinUtil { } /** - * A query time join using global ordinals over a dedicated join field. + * Delegates to {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode, MultiDocValues.OrdinalMap, int, int)}, + * but disables the min and max filtering. * - * This join has certain restrictions and requirements: - * 1) A document can only refer to one other document. (but can be referred by one or more documents) - * 2) Documents on each side of the join must be distinguishable. Typically this can be done by adding an extra field - * that identifies the "from" and "to" side and then the fromQuery and toQuery must take the this into account. - * 3) There must be a single sorted doc values join field used by both the "from" and "to" documents. This join field - * should store the join values as UTF-8 strings. - * 4) An ordinal map must be provided that is created on top of the join field. - * - * @param joinField The {@link org.apache.lucene.index.SortedDocValues} field containing the join values + * @param joinField The {@link SortedDocValues} field containing the join values * @param fromQuery The query containing the actual user query. Also the fromQuery can only match "from" documents. * @param toQuery The query identifying all documents on the "to" side. * @param searcher The index searcher used to execute the from query @@ -123,6 +116,47 @@ public final class JoinUtil { IndexSearcher searcher, ScoreMode scoreMode, MultiDocValues.OrdinalMap ordinalMap) throws IOException { + return createJoinQuery(joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, 0, Integer.MAX_VALUE); + } + + /** + * A query time join using global ordinals over a dedicated join field. + * + * This join has certain restrictions and requirements: + * 1) A document can only refer to one other document. (but can be referred by one or more documents) + * 2) Documents on each side of the join must be distinguishable. Typically this can be done by adding an extra field + * that identifies the "from" and "to" side and then the fromQuery and toQuery must take the this into account. + * 3) There must be a single sorted doc values join field used by both the "from" and "to" documents. This join field + * should store the join values as UTF-8 strings. + * 4) An ordinal map must be provided that is created on top of the join field. + * + * Note: min and max filtering and the avg score mode will require this join to keep track of the number of times + * a document matches per join value. This will increase the per join cost in terms of execution time and memory. + * + * @param joinField The {@link SortedDocValues} field containing the join values + * @param fromQuery The query containing the actual user query. Also the fromQuery can only match "from" documents. + * @param toQuery The query identifying all documents on the "to" side. + * @param searcher The index searcher used to execute the from query + * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query + * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map + * needs to be provided. + * @param min Optionally the minimum number of "from" documents that are required to match for a "to" document + * to be a match. The min is inclusive. Setting min to 0 and max to Interger.MAX_VALUE + * disables the min and max "from" documents filtering + * @param max Optionally the maximum number of "from" documents that are allowed to match for a "to" document + * to be a match. The max is inclusive. Setting min to 0 and max to Interger.MAX_VALUE + * disables the min and max "from" documents filtering + * @return a {@link Query} instance that can be used to join documents based on the join field + * @throws IOException If I/O related errors occur + */ + public static Query createJoinQuery(String joinField, + Query fromQuery, + Query toQuery, + IndexSearcher searcher, + ScoreMode scoreMode, + MultiDocValues.OrdinalMap ordinalMap, + int min, + int max) throws IOException { IndexReader indexReader = searcher.getIndexReader(); int numSegments = indexReader.leaves().size(); final long valueCount; @@ -146,31 +180,34 @@ public final class JoinUtil { } final Query rewrittenFromQuery = searcher.rewrite(fromQuery); - if (scoreMode == ScoreMode.None) { - GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount); - searcher.search(rewrittenFromQuery, globalOrdinalsCollector); - return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader); - } - GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector; switch (scoreMode) { case Total: - globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount); + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount, min, max); break; case Min: - globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount); + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount, min, max); break; case Max: - globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount); + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount, min, max); break; case Avg: - globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount); + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount, min, max); break; + case None: + if (min <= 0 && max == Integer.MAX_VALUE) { + GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount); + searcher.search(rewrittenFromQuery, globalOrdinalsCollector); + return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader); + } else { + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.NoScore(joinField, ordinalMap, valueCount, min, max); + break; + } default: throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode)); } searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector); - return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader); + return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, joinField, ordinalMap, toQuery, rewrittenFromQuery, min, max, indexReader); } } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java index dd7e3f7dfc8..2e9f0fb1f58 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java @@ -61,6 +61,7 @@ import org.apache.lucene.search.SimpleCollector; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BitSet; @@ -412,7 +413,7 @@ public class TestJoinUtil extends LuceneTestCase { String childId = Integer.toString(p + c); Document childDoc = new Document(); childDoc.add(new StringField("id", childId, Field.Store.YES)); - parentDoc.add(new StringField("type", "from", Field.Store.NO)); + childDoc.add(new StringField("type", "from", Field.Store.NO)); childDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId))); int price = random().nextInt(1000); childDoc.add(new NumericDocValuesField(priceField, price)); @@ -459,6 +460,76 @@ public class TestJoinUtil extends LuceneTestCase { dir.close(); } + public void testMinMaxDocs() throws Exception { + Directory dir = newDirectory(); + RandomIndexWriter iw = new RandomIndexWriter( + random(), + dir, + newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)) + ); + + int minChildDocsPerParent = 2; + int maxChildDocsPerParent = 16; + int numParents = RandomInts.randomIntBetween(random(), 16, 64); + int[] childDocsPerParent = new int[numParents]; + for (int p = 0; p < numParents; p++) { + String parentId = Integer.toString(p); + Document parentDoc = new Document(); + parentDoc.add(new StringField("id", parentId, Field.Store.YES)); + parentDoc.add(new StringField("type", "to", Field.Store.NO)); + parentDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId))); + iw.addDocument(parentDoc); + int numChildren = RandomInts.randomIntBetween(random(), minChildDocsPerParent, maxChildDocsPerParent); + childDocsPerParent[p] = numChildren; + for (int c = 0; c < numChildren; c++) { + String childId = Integer.toString(p + c); + Document childDoc = new Document(); + childDoc.add(new StringField("id", childId, Field.Store.YES)); + childDoc.add(new StringField("type", "from", Field.Store.NO)); + childDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId))); + iw.addDocument(childDoc); + } + } + iw.close(); + + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + SortedDocValues[] values = new SortedDocValues[searcher.getIndexReader().leaves().size()]; + for (LeafReaderContext leadContext : searcher.getIndexReader().leaves()) { + values[leadContext.ord] = DocValues.getSorted(leadContext.reader(), "join_field"); + } + MultiDocValues.OrdinalMap ordinalMap = MultiDocValues.OrdinalMap.build( + searcher.getIndexReader().getCoreCacheKey(), values, PackedInts.DEFAULT + ); + Query fromQuery = new TermQuery(new Term("type", "from")); + Query toQuery = new TermQuery(new Term("type", "to")); + + int iters = RandomInts.randomIntBetween(random(), 3, 9); + for (int i = 1; i <= iters; i++) { + final ScoreMode scoreMode = ScoreMode.values()[random().nextInt(ScoreMode.values().length)]; + int min = RandomInts.randomIntBetween(random(), minChildDocsPerParent, maxChildDocsPerParent - 1); + int max = RandomInts.randomIntBetween(random(), min, maxChildDocsPerParent); + if (VERBOSE) { + System.out.println("iter=" + i); + System.out.println("scoreMode=" + scoreMode); + System.out.println("min=" + min); + System.out.println("max=" + max); + } + Query joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, searcher, scoreMode, ordinalMap, min, max); + TotalHitCountCollector collector = new TotalHitCountCollector(); + searcher.search(joinQuery, collector); + int expectedCount = 0; + for (int numChildDocs : childDocsPerParent) { + if (numChildDocs >= min && numChildDocs <= max) { + expectedCount++; + } + } + assertEquals(expectedCount, collector.getTotalHits()); + } + + searcher.getIndexReader().close(); + dir.close(); + } + // TermsWithScoreCollector.MV.Avg forgets to grow beyond TermsWithScoreCollector.INITIAL_ARRAY_SIZE public void testOverflowTermsWithScoreCollector() throws Exception { test300spartans(true, ScoreMode.Avg);