From 36acada76293302d83cab60c086a70f8c17e17e6 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Tue, 29 May 2012 20:37:31 +0000 Subject: [PATCH] LUCENE-4043: Added scoring support via score mode for query time joining. git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1343966 13f79535-47bb-0310-9956-ffa450edef68 --- lucene/CHANGES.txt | 3 + .../org/apache/lucene/index/DocTermOrds.java | 7 + .../apache/lucene/search/join/JoinUtil.java | 39 +- .../apache/lucene/search/join/ScoreMode.java | 45 ++ .../search/join/TermsIncludingScoreQuery.java | 271 ++++++++++ .../search/join/TermsWithScoreCollector.java | 292 +++++++++++ .../search/join/ToParentBlockJoinQuery.java | 17 - .../apache/lucene/search/join/package.html | 10 +- .../lucene/search/join/TestBlockJoin.java | 26 +- .../lucene/search/join/TestJoinUtil.java | 480 ++++++++++++++++-- 10 files changed, 1111 insertions(+), 79 deletions(-) create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/ScoreMode.java create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/TermsWithScoreCollector.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d708fd9ea80..b82405e4134 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -880,6 +880,9 @@ New features returning results after a specified FieldDoc for deep paging. (Mike McCandless) +* LUCENE-4043: Added scoring support via score mode for query time joining. 
+ (Martijn van Groningen, Mike McCandless) + Optimizations * LUCENE-2588: Don't store unnecessary suffixes when writing the terms diff --git a/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java b/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java index 4b120b6a47e..7a3750f0522 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java @@ -220,6 +220,13 @@ public class DocTermOrds { return numTermsInField; } + /** + * @return Whether this DocTermOrds instance is empty. + */ + public boolean isEmpty() { + return index == null; + } + /** Subclass can override this */ protected void visitTerm(TermsEnum te, int termNum) throws IOException { } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java index 05586cd6ab9..9502d04223c 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java @@ -38,12 +38,24 @@ public final class JoinUtil { *

* Execute the returned query with a {@link IndexSearcher} to retrieve all documents that have the same terms in the * to field that match with documents matching the specified fromQuery and have the same terms in the from field. + *

+ * In the case a single document relates to more than one document the multipleValuesPerDocument option + * should be set to true. When multipleValuesPerDocument is set to true only + * the score from the first encountered join value originating from the 'from' side is mapped into the 'to' side, + * even when a second join value related to a specific document yields a higher score. Obviously this + * doesn't apply in the case that {@link ScoreMode#None} is used, since no scores are computed at all. + *

+ * Memory considerations: During joining all unique join values are kept in memory. On top of that when the scoreMode + * isn't set to {@link ScoreMode#None} a float value per unique join value is kept in memory for computing scores. + * When scoreMode is set to {@link ScoreMode#Avg} also an additional integer value is kept in memory per unique + * join value. * * @param fromField The from field to join from * @param multipleValuesPerDocument Whether the from field has multiple terms per document * @param toField The to field to join to * @param fromQuery The query to match documents on the from side * @param fromSearcher The searcher that executed the specified fromQuery + * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query * @return a {@link Query} instance that can be used to join documents based on the * terms in the from and to field * @throws IOException If I/O related errors occur @@ -52,10 +64,29 @@ public final class JoinUtil { boolean multipleValuesPerDocument, String toField, Query fromQuery, - IndexSearcher fromSearcher) throws IOException { - TermsCollector termsCollector = TermsCollector.create(fromField, multipleValuesPerDocument); - fromSearcher.search(fromQuery, termsCollector); - return new TermsQuery(toField, termsCollector.getCollectorTerms()); + IndexSearcher fromSearcher, + ScoreMode scoreMode) throws IOException { + switch (scoreMode) { + case None: + TermsCollector termsCollector = TermsCollector.create(fromField, multipleValuesPerDocument); + fromSearcher.search(fromQuery, termsCollector); + return new TermsQuery(toField, termsCollector.getCollectorTerms()); + case Total: + case Max: + case Avg: + TermsWithScoreCollector termsWithScoreCollector = + TermsWithScoreCollector.create(fromField, multipleValuesPerDocument, scoreMode); + fromSearcher.search(fromQuery, termsWithScoreCollector); + return new TermsIncludingScoreQuery( + toField, + multipleValuesPerDocument, + 
termsWithScoreCollector.getCollectedTerms(), + termsWithScoreCollector.getScoresPerTerm(), + fromQuery + ); + default: + throw new IllegalArgumentException(String.format("Score mode %s isn't supported.", scoreMode)); + } } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ScoreMode.java b/lucene/join/src/java/org/apache/lucene/search/join/ScoreMode.java new file mode 100644 index 00000000000..5b6fc1085c0 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/ScoreMode.java @@ -0,0 +1,45 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * How to aggregate multiple child hit scores into a single parent score. + */ +public enum ScoreMode { + + /** + * Do no scoring. + */ + None, + + /** + * Parent hit's score is the average of all child scores. + */ + Avg, + + /** + * Parent hit's score is the max of all child scores. + */ + Max, + + /** + * Parent hit's score is the sum of all child scores. 
+ */ + Total + +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java new file mode 100644 index 00000000000..e6b5d73801d --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java @@ -0,0 +1,271 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.lucene.index.AtomicReaderContext; +import org.apache.lucene.index.DocsEnum; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.ComplexExplanation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefHash; +import org.apache.lucene.util.FixedBitSet; + +import java.io.IOException; +import java.util.Set; + +class TermsIncludingScoreQuery extends Query { + + final String field; + final boolean multipleValuesPerDocument; + final BytesRefHash terms; + final float[] scores; + final int[] ords; + final Query originalQuery; + final Query unwrittenOriginalQuery; + + TermsIncludingScoreQuery(String field, boolean multipleValuesPerDocument, BytesRefHash terms, float[] scores, Query originalQuery) { + this.field = field; + this.multipleValuesPerDocument = multipleValuesPerDocument; + this.terms = terms; + this.scores = scores; + this.originalQuery = originalQuery; + this.ords = terms.sort(BytesRef.getUTF8SortedAsUnicodeComparator()); + this.unwrittenOriginalQuery = originalQuery; + } + + private TermsIncludingScoreQuery(String field, boolean multipleValuesPerDocument, BytesRefHash terms, float[] scores, int[] ords, Query originalQuery, Query unwrittenOriginalQuery) { + this.field = field; + this.multipleValuesPerDocument = multipleValuesPerDocument; + this.terms = terms; + this.scores = scores; + this.originalQuery = originalQuery; + this.ords = ords; + this.unwrittenOriginalQuery = unwrittenOriginalQuery; + } + + public String toString(String string) { + return 
String.format("TermsIncludingScoreQuery{field=%s;originalQuery=%s}", field, unwrittenOriginalQuery); + } + + @Override + public void extractTerms(Set terms) { + originalQuery.extractTerms(terms); + } + + @Override + public Query rewrite(IndexReader reader) throws IOException { + final Query originalQueryRewrite = originalQuery.rewrite(reader); + if (originalQueryRewrite != originalQuery) { + Query rewritten = new TermsIncludingScoreQuery(field, multipleValuesPerDocument, terms, scores, + ords, originalQueryRewrite, originalQuery); + rewritten.setBoost(getBoost()); + return rewritten; + } else { + return this; + } + } + + @Override + public Weight createWeight(IndexSearcher searcher) throws IOException { + final Weight originalWeight = originalQuery.createWeight(searcher); + return new Weight() { + + private TermsEnum segmentTermsEnum; + + public Explanation explain(AtomicReaderContext context, int doc) throws IOException { + SVInnerScorer scorer = (SVInnerScorer) scorer(context, true, false, context.reader().getLiveDocs()); + if (scorer != null) { + if (scorer.advance(doc) == doc) { + return scorer.explain(); + } + } + return new ComplexExplanation(false, 0.0f, "Not a match"); + } + + public Query getQuery() { + return TermsIncludingScoreQuery.this; + } + + public float getValueForNormalization() throws IOException { + return originalWeight.getValueForNormalization() * TermsIncludingScoreQuery.this.getBoost() * TermsIncludingScoreQuery.this.getBoost(); + } + + public void normalize(float norm, float topLevelBoost) { + originalWeight.normalize(norm, topLevelBoost * TermsIncludingScoreQuery.this.getBoost()); + } + + public Scorer scorer(AtomicReaderContext context, boolean scoreDocsInOrder, boolean topScorer, Bits acceptDocs) throws IOException { + Terms terms = context.reader().terms(field); + if (terms == null) { + return null; + } + + segmentTermsEnum = terms.iterator(segmentTermsEnum); + if (multipleValuesPerDocument) { + return new MVInnerScorer(this, 
acceptDocs, segmentTermsEnum, context.reader().maxDoc()); + } else { + return new SVInnerScorer(this, acceptDocs, segmentTermsEnum); + } + } + }; + } + + // This impl assumes that the 'join' values are used uniquely per doc per field. Used for one to many relations. + class SVInnerScorer extends Scorer { + + final BytesRef spare = new BytesRef(); + final Bits acceptDocs; + final TermsEnum termsEnum; + + int upto; + DocsEnum docsEnum; + DocsEnum reuse; + int scoreUpto; + + SVInnerScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum) { + super(weight); + this.acceptDocs = acceptDocs; + this.termsEnum = termsEnum; + } + + public float score() throws IOException { + return scores[ords[scoreUpto]]; + } + + public Explanation explain() throws IOException { + return new ComplexExplanation(true, score(), "Score based on join value " + termsEnum.term().utf8ToString()); + } + + public int docID() { + return docsEnum != null ? docsEnum.docID() : DocIdSetIterator.NO_MORE_DOCS; + } + + public int nextDoc() throws IOException { + if (docsEnum != null) { + int docId = docsEnum.nextDoc(); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + docsEnum = null; + } else { + return docId; + } + } + + do { + if (upto == terms.size()) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + scoreUpto = upto; + TermsEnum.SeekStatus status = termsEnum.seekCeil(terms.get(ords[upto++], spare), true); + if (status == TermsEnum.SeekStatus.FOUND) { + docsEnum = reuse = termsEnum.docs(acceptDocs, reuse, false); + } + } while (docsEnum == null); + + return docsEnum.nextDoc(); + } + + public int advance(int target) throws IOException { + int docId; + do { + docId = nextDoc(); + if (docId < target) { + int tempDocId = docsEnum.advance(target); + if (tempDocId == target) { + docId = tempDocId; + break; + } + } else if (docId == target) { + break; + } + docsEnum = null; // goto the next ord. 
+ } while (docId != DocIdSetIterator.NO_MORE_DOCS); + return docId; + } + } + + // This impl that tracks whether a docid has already been emitted. This check makes sure that docs aren't emitted + // twice for different join values. This means that the first encountered join value determines the score of a document + // even if other join values yield a higher score. + class MVInnerScorer extends SVInnerScorer { + + final FixedBitSet alreadyEmittedDocs; + + MVInnerScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum, int maxDoc) { + super(weight, acceptDocs, termsEnum); + alreadyEmittedDocs = new FixedBitSet(maxDoc); + } + + public int nextDoc() throws IOException { + if (docsEnum != null) { + int docId; + do { + docId = docsEnum.nextDoc(); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + } while (alreadyEmittedDocs.get(docId)); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + docsEnum = null; + } else { + alreadyEmittedDocs.set(docId); + return docId; + } + } + + for (;;) { + do { + if (upto == terms.size()) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + scoreUpto = upto; + TermsEnum.SeekStatus status = termsEnum.seekCeil(terms.get(ords[upto++], spare), true); + if (status == TermsEnum.SeekStatus.FOUND) { + docsEnum = reuse = termsEnum.docs(acceptDocs, reuse, false); + } + } while (docsEnum == null); + + int docId; + do { + docId = docsEnum.nextDoc(); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + } while (alreadyEmittedDocs.get(docId)); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + docsEnum = null; + } else { + alreadyEmittedDocs.set(docId); + return docId; + } + } + } + } + +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/TermsWithScoreCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/TermsWithScoreCollector.java new file mode 100644 index 00000000000..af7c2debd64 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/TermsWithScoreCollector.java @@ -0,0 +1,292 @@ 
+package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.AtomicReaderContext; +import org.apache.lucene.index.DocTermOrds; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.FieldCache; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefHash; + +import java.io.IOException; + +abstract class TermsWithScoreCollector extends Collector { + + private final static int INITIAL_ARRAY_SIZE = 256; + + final String field; + final BytesRefHash collectedTerms = new BytesRefHash(); + final ScoreMode scoreMode; + + Scorer scorer; + float[] scoreSums = new float[INITIAL_ARRAY_SIZE]; + + TermsWithScoreCollector(String field, ScoreMode scoreMode) { + this.field = field; + this.scoreMode = scoreMode; + } + + public BytesRefHash getCollectedTerms() { + return collectedTerms; + } + + public float[] getScoresPerTerm() { + return scoreSums; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + public boolean acceptsDocsOutOfOrder() { + return true; + } 
+ + /** + * Chooses the right {@link TermsWithScoreCollector} implementation. + * + * @param field The field to collect terms for + * @param multipleValuesPerDocument Whether the field to collect terms for has multiple values per document. + * @return a {@link TermsWithScoreCollector} instance + */ + static TermsWithScoreCollector create(String field, boolean multipleValuesPerDocument, ScoreMode scoreMode) { + if (multipleValuesPerDocument) { + switch (scoreMode) { + case Avg: + return new MV.Avg(field); + default: + return new MV(field, scoreMode); + } + } else { + switch (scoreMode) { + case Avg: + return new SV.Avg(field); + default: + return new SV(field, scoreMode); + } + } + } + + // impl that works with single value per document + static class SV extends TermsWithScoreCollector { + + final BytesRef spare = new BytesRef(); + FieldCache.DocTerms fromDocTerms; + + SV(String field, ScoreMode scoreMode) { + super(field, scoreMode); + } + + public void collect(int doc) throws IOException { + int ord = collectedTerms.add(fromDocTerms.getTerm(doc, spare)); + if (ord < 0) { + ord = -ord - 1; + } else { + if (ord >= scoreSums.length) { + scoreSums = ArrayUtil.grow(scoreSums); + } + } + + float current = scorer.score(); + float existing = scoreSums[ord]; + if (Float.compare(existing, 0.0f) == 0) { + scoreSums[ord] = current; + } else { + switch (scoreMode) { + case Total: + scoreSums[ord] = scoreSums[ord] + current; + break; + case Max: + if (current > existing) { + scoreSums[ord] = current; + } + } + } + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + fromDocTerms = FieldCache.DEFAULT.getTerms(context.reader(), field); + } + + static class Avg extends SV { + + int[] scoreCounts = new int[INITIAL_ARRAY_SIZE]; + + Avg(String field) { + super(field, ScoreMode.Avg); + } + + @Override + public void collect(int doc) throws IOException { + int ord = collectedTerms.add(fromDocTerms.getTerm(doc, spare)); + if (ord < 0) { + ord = -ord - 1; + 
} else { + if (ord >= scoreSums.length) { + scoreSums = ArrayUtil.grow(scoreSums); + scoreCounts = ArrayUtil.grow(scoreCounts); + } + } + + float current = scorer.score(); + float existing = scoreSums[ord]; + if (Float.compare(existing, 0.0f) == 0) { + scoreSums[ord] = current; + scoreCounts[ord] = 1; + } else { + scoreSums[ord] = scoreSums[ord] + current; + scoreCounts[ord]++; + } + } + + @Override + public float[] getScoresPerTerm() { + if (scoreCounts != null) { + for (int i = 0; i < scoreCounts.length; i++) { + scoreSums[i] = scoreSums[i] / scoreCounts[i]; + } + scoreCounts = null; + } + return scoreSums; + } + } + } + + // impl that works with multiple values per document + static class MV extends TermsWithScoreCollector { + + DocTermOrds fromDocTermOrds; + TermsEnum docTermsEnum; + DocTermOrds.TermOrdsIterator reuse; + + MV(String field, ScoreMode scoreMode) { + super(field, scoreMode); + } + + public void collect(int doc) throws IOException { + reuse = fromDocTermOrds.lookup(doc, reuse); + int[] buffer = new int[5]; + + int chunk; + do { + chunk = reuse.read(buffer); + if (chunk == 0) { + return; + } + + for (int idx = 0; idx < chunk; idx++) { + int key = buffer[idx]; + docTermsEnum.seekExact((long) key); + int ord = collectedTerms.add(docTermsEnum.term()); + if (ord < 0) { + ord = -ord - 1; + } else { + if (ord >= scoreSums.length) { + scoreSums = ArrayUtil.grow(scoreSums); + } + } + + final float current = scorer.score(); + final float existing = scoreSums[ord]; + if (Float.compare(existing, 0.0f) == 0) { + scoreSums[ord] = current; + } else { + switch (scoreMode) { + case Total: + scoreSums[ord] = existing + current; + break; + case Max: + if (current > existing) { + scoreSums[ord] = current; + } + } + } + } + } while (chunk >= buffer.length); + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + fromDocTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), field); + docTermsEnum = 
fromDocTermOrds.getOrdTermsEnum(context.reader()); + reuse = null; // LUCENE-3377 needs to be fixed first then this statement can be removed... + } + + static class Avg extends MV { + + int[] scoreCounts = new int[INITIAL_ARRAY_SIZE]; + + Avg(String field) { + super(field, ScoreMode.Avg); + } + + @Override + public void collect(int doc) throws IOException { + reuse = fromDocTermOrds.lookup(doc, reuse); + int[] buffer = new int[5]; + + int chunk; + do { + chunk = reuse.read(buffer); + if (chunk == 0) { + return; + } + + for (int idx = 0; idx < chunk; idx++) { + int key = buffer[idx]; + docTermsEnum.seekExact((long) key); + int ord = collectedTerms.add(docTermsEnum.term()); + if (ord < 0) { + ord = -ord - 1; + } else { + if (ord >= scoreSums.length) { + scoreSums = ArrayUtil.grow(scoreSums); + scoreCounts = ArrayUtil.grow(scoreCounts); + } + } + + float current = scorer.score(); + float existing = scoreSums[ord]; + if (Float.compare(existing, 0.0f) == 0) { + scoreSums[ord] = current; + scoreCounts[ord] = 1; + } else { + scoreSums[ord] = scoreSums[ord] + current; + scoreCounts[ord]++; + } + } + } while (chunk >= buffer.length); + } + + @Override + public float[] getScoresPerTerm() { + if (scoreCounts != null) { + for (int i = 0; i < scoreCounts.length; i++) { + scoreSums[i] = scoreSums[i] / scoreCounts[i]; + } + scoreCounts = null; + } + return scoreSums; + } + } + } + +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java index 2c22041c6d3..787c08b07ad 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java @@ -33,7 +33,6 @@ import org.apache.lucene.search.Filter; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; -import 
org.apache.lucene.search.Scorer.ChildScorer; import org.apache.lucene.search.Weight; import org.apache.lucene.search.grouping.TopGroups; import org.apache.lucene.util.ArrayUtil; @@ -82,24 +81,8 @@ import org.apache.lucene.util.FixedBitSet; * * @lucene.experimental */ - public class ToParentBlockJoinQuery extends Query { - /** How to aggregate multiple child hit scores into a - * single parent score. */ - public static enum ScoreMode { - /** Do no scoring. */ - None, - /** Parent hit's score is the average of all child - scores. */ - Avg, - /** Parent hit's score is the max of all child - scores. */ - Max, - /** Parent hit's score is the sum of all child - scores. */ - Total}; - private final Filter parentsFilter; private final Query childQuery; diff --git a/lucene/join/src/java/org/apache/lucene/search/join/package.html b/lucene/join/src/java/org/apache/lucene/search/join/package.html index 5cea27a1c8a..036ef630ded 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/package.html +++ b/lucene/join/src/java/org/apache/lucene/search/join/package.html @@ -56,7 +56,7 @@ any query matching parent documents, creating the joined query matching only child documents. -

Search-time joins

+

Query-time joins

The query time joining is index term based and implemented as two pass search. The first pass collects all the terms from a fromField @@ -68,22 +68,26 @@

  • fromField: The from field to join from.
  • fromQuery: The query executed to collect the from terms. This is usually the user specified query.
  • multipleValuesPerDocument: Whether the fromField contains more than one value per document +
  • scoreMode: Defines how scores are translated to the other join side. If you don't care about scoring + use {@link org.apache.lucene.search.join.ScoreMode#None} mode. This will disable scoring and is therefore more + efficient (requires less memory and is faster).
  • toField: The to field to join to

    Basically the query-time joining is accessible from one static method. The user of this method supplies the method with the described input and a IndexSearcher where the from terms need to be collected from. The returned query can be executed with the same IndexSearcher, but also with another IndexSearcher. - Example usage of the {@link org.apache.lucene.search.join.JoinUtil#createJoinQuery(String, boolean, String, org.apache.lucene.search.Query, org.apache.lucene.search.IndexSearcher) + Example usage of the {@link org.apache.lucene.search.join.JoinUtil#createJoinQuery(String, boolean, String, org.apache.lucene.search.Query, org.apache.lucene.search.IndexSearcher, org.apache.lucene.search.join.ScoreMode) JoinUtil.createJoinQuery()} :

       String fromField = "from"; // Name of the from field
       boolean multipleValuesPerDocument = false; // Set only yo true in the case when your fromField has multiple values per document in your index
       String toField = "to"; // Name of the to field
    +  ScoreMode scoreMode = ScoreMode.Max; // Defines how the scores are translated into the other side of the join.
       Query fromQuery = new TermQuery(new Term("content", searchTerm)); // Query executed to collect from values to join to the to values
     
    -  Query joinQuery = JoinUtil.createJoinQuery(fromField, multipleValuesPerDocument, toField, fromQuery, fromSearcher);
    +  Query joinQuery = JoinUtil.createJoinQuery(fromField, multipleValuesPerDocument, toField, fromQuery, fromSearcher, scoreMode);
       TopDocs topDocs = toSearcher.search(joinQuery, 10); // Note: toSearcher can be the same as the fromSearcher
       // Render topDocs...
     
    diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java index 34f7b656119..c81117702b9 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java @@ -96,7 +96,7 @@ public class TestBlockJoin extends LuceneTestCase { // Wrap the child document query to 'join' any matches // up to corresponding parent: - ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg); + ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ScoreMode.Avg); // Combine the parent and nested child queries into a single query for a candidate BooleanQuery fullQuery = new BooleanQuery(); @@ -198,7 +198,7 @@ public class TestBlockJoin extends LuceneTestCase { // Wrap the child document query to 'join' any matches // up to corresponding parent: - ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg); + ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ScoreMode.Avg); assertEquals("no filter - both passed", 2, s.search(childJoinQuery, 10).totalHits); @@ -259,7 +259,7 @@ public class TestBlockJoin extends LuceneTestCase { w.close(); IndexSearcher s = newSearcher(r); - ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(new MatchAllDocsQuery(), new QueryWrapperFilter(new MatchAllDocsQuery()), ToParentBlockJoinQuery.ScoreMode.Avg); + ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(new MatchAllDocsQuery(), new QueryWrapperFilter(new MatchAllDocsQuery()), ScoreMode.Avg); s.search(q, 10); BooleanQuery bq = new BooleanQuery(); bq.setBoost(2f); // we boost the BQ @@ -493,15 +493,15 @@ public class TestBlockJoin extends LuceneTestCase { } final int x = 
random().nextInt(4); - final ToParentBlockJoinQuery.ScoreMode agg; + final ScoreMode agg; if (x == 0) { - agg = ToParentBlockJoinQuery.ScoreMode.None; + agg = ScoreMode.None; } else if (x == 1) { - agg = ToParentBlockJoinQuery.ScoreMode.Max; + agg = ScoreMode.Max; } else if (x == 2) { - agg = ToParentBlockJoinQuery.ScoreMode.Total; + agg = ScoreMode.Total; } else { - agg = ToParentBlockJoinQuery.ScoreMode.Avg; + agg = ScoreMode.Avg; } final ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, agg); @@ -584,7 +584,7 @@ public class TestBlockJoin extends LuceneTestCase { final boolean trackScores; final boolean trackMaxScore; - if (agg == ToParentBlockJoinQuery.ScoreMode.None) { + if (agg == ScoreMode.None) { trackScores = false; trackMaxScore = false; } else { @@ -881,8 +881,8 @@ public class TestBlockJoin extends LuceneTestCase { // Wrap the child document query to 'join' any matches // up to corresponding parent: - ToParentBlockJoinQuery childJobJoinQuery = new ToParentBlockJoinQuery(childJobQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg); - ToParentBlockJoinQuery childQualificationJoinQuery = new ToParentBlockJoinQuery(childQualificationQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg); + ToParentBlockJoinQuery childJobJoinQuery = new ToParentBlockJoinQuery(childJobQuery, parentsFilter, ScoreMode.Avg); + ToParentBlockJoinQuery childQualificationJoinQuery = new ToParentBlockJoinQuery(childQualificationQuery, parentsFilter, ScoreMode.Avg); // Combine the parent and nested child queries into a single query for a candidate BooleanQuery fullQuery = new BooleanQuery(); @@ -952,7 +952,7 @@ public class TestBlockJoin extends LuceneTestCase { new QueryWrapperFilter( new TermQuery(new Term("parent", "1")))); - ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ToParentBlockJoinQuery.ScoreMode.Avg); + ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ScoreMode.Avg); 
Weight weight = s.createNormalizedWeight(q); DocIdSetIterator disi = weight.scorer(s.getIndexReader().getTopReaderContext().leaves()[0], true, true, null); assertEquals(1, disi.advance(1)); @@ -986,7 +986,7 @@ public class TestBlockJoin extends LuceneTestCase { new QueryWrapperFilter( new TermQuery(new Term("isparent", "yes")))); - ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ToParentBlockJoinQuery.ScoreMode.Avg); + ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ScoreMode.Avg); Weight weight = s.createNormalizedWeight(q); DocIdSetIterator disi = weight.scorer(s.getIndexReader().getTopReaderContext().leaves()[0], true, true, null); assertEquals(2, disi.advance(0)); diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java index 8040ded1176..7c7ca12c668 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java @@ -22,8 +22,26 @@ import org.apache.lucene.analysis.MockTokenizer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.TextField; -import org.apache.lucene.index.*; -import org.apache.lucene.search.*; +import org.apache.lucene.index.AtomicReaderContext; +import org.apache.lucene.index.DocTermOrds; +import org.apache.lucene.index.DocsEnum; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.FieldCache; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import 
org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; @@ -49,45 +67,45 @@ public class TestJoinUtil extends LuceneTestCase { // 0 Document doc = new Document(); - doc.add(new Field("description", "random text", TextField.TYPE_STORED)); - doc.add(new Field("name", "name1", TextField.TYPE_STORED)); - doc.add(new Field(idField, "1", TextField.TYPE_STORED)); + doc.add(new Field("description", "random text", TextField.TYPE_UNSTORED)); + doc.add(new Field("name", "name1", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "1", TextField.TYPE_UNSTORED)); w.addDocument(doc); // 1 doc = new Document(); - doc.add(new Field("price", "10.0", TextField.TYPE_STORED)); - doc.add(new Field(idField, "2", TextField.TYPE_STORED)); - doc.add(new Field(toField, "1", TextField.TYPE_STORED)); + doc.add(new Field("price", "10.0", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "2", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED)); w.addDocument(doc); // 2 doc = new Document(); - doc.add(new Field("price", "20.0", TextField.TYPE_STORED)); - doc.add(new Field(idField, "3", TextField.TYPE_STORED)); - doc.add(new Field(toField, "1", TextField.TYPE_STORED)); + doc.add(new Field("price", "20.0", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "3", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED)); w.addDocument(doc); // 3 doc = new Document(); - doc.add(new Field("description", "more random text", TextField.TYPE_STORED)); - doc.add(new Field("name", "name2", TextField.TYPE_STORED)); - doc.add(new Field(idField, "4", TextField.TYPE_STORED)); + doc.add(new Field("description", 
"more random text", TextField.TYPE_UNSTORED)); + doc.add(new Field("name", "name2", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "4", TextField.TYPE_UNSTORED)); w.addDocument(doc); w.commit(); // 4 doc = new Document(); - doc.add(new Field("price", "10.0", TextField.TYPE_STORED)); - doc.add(new Field(idField, "5", TextField.TYPE_STORED)); - doc.add(new Field(toField, "4", TextField.TYPE_STORED)); + doc.add(new Field("price", "10.0", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "5", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED)); w.addDocument(doc); // 5 doc = new Document(); - doc.add(new Field("price", "20.0", TextField.TYPE_STORED)); - doc.add(new Field(idField, "6", TextField.TYPE_STORED)); - doc.add(new Field(toField, "4", TextField.TYPE_STORED)); + doc.add(new Field("price", "20.0", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "6", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED)); w.addDocument(doc); IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); @@ -95,21 +113,21 @@ public class TestJoinUtil extends LuceneTestCase { // Search for product Query joinQuery = - JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name2")), indexSearcher); + JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name2")), indexSearcher, ScoreMode.None); TopDocs result = indexSearcher.search(joinQuery, 10); assertEquals(2, result.totalHits); assertEquals(4, result.scoreDocs[0].doc); assertEquals(5, result.scoreDocs[1].doc); - joinQuery = JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name1")), indexSearcher); + joinQuery = JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name1")), indexSearcher, ScoreMode.None); result = indexSearcher.search(joinQuery, 10); assertEquals(2, result.totalHits); assertEquals(1, result.scoreDocs[0].doc); 
assertEquals(2, result.scoreDocs[1].doc); // Search for offer - joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("id", "5")), indexSearcher); + joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("id", "5")), indexSearcher, ScoreMode.None); result = indexSearcher.search(joinQuery, 10); assertEquals(1, result.totalHits); assertEquals(3, result.scoreDocs[0].doc); @@ -118,6 +136,96 @@ public class TestJoinUtil extends LuceneTestCase { dir.close(); } + public void testSimpleWithScoring() throws Exception { + final String idField = "id"; + final String toField = "movieId"; + + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random(), + dir, + newIndexWriterConfig(TEST_VERSION_CURRENT, + new MockAnalyzer(random())).setMergePolicy(newLogMergePolicy())); + + // 0 + Document doc = new Document(); + doc.add(new Field("description", "A random movie", TextField.TYPE_UNSTORED)); + doc.add(new Field("name", "Movie 1", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "1", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 1 + doc = new Document(); + doc.add(new Field("subtitle", "The first subtitle of this movie", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "2", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 2 + doc = new Document(); + doc.add(new Field("subtitle", "random subtitle; random event movie", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "3", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 3 + doc = new Document(); + doc.add(new Field("description", "A second random movie", TextField.TYPE_UNSTORED)); + doc.add(new Field("name", "Movie 2", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "4", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + w.commit(); + + // 4 + doc = new Document(); + 
doc.add(new Field("subtitle", "a very random event happened during christmas night", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "5", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 5 + doc = new Document(); + doc.add(new Field("subtitle", "movie end movie test 123 test 123 random", TextField.TYPE_UNSTORED)); + doc.add(new Field(idField, "6", TextField.TYPE_UNSTORED)); + doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); + w.close(); + + // Search for movie via subtitle + Query joinQuery = + JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "random")), indexSearcher, ScoreMode.Max); + TopDocs result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(0, result.scoreDocs[0].doc); + assertEquals(3, result.scoreDocs[1].doc); + + // Score mode max. 
+ joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "movie")), indexSearcher, ScoreMode.Max); + result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(3, result.scoreDocs[0].doc); + assertEquals(0, result.scoreDocs[1].doc); + + // Score mode total + joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "movie")), indexSearcher, ScoreMode.Total); + result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(0, result.scoreDocs[0].doc); + assertEquals(3, result.scoreDocs[1].doc); + + //Score mode avg + joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "movie")), indexSearcher, ScoreMode.Avg); + result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(3, result.scoreDocs[0].doc); + assertEquals(0, result.scoreDocs[1].doc); + + indexSearcher.getIndexReader().close(); + dir.close(); + } + @Test public void testSingleValueRandomJoin() throws Exception { int maxIndexIter = _TestUtil.nextInt(random(), 6, 12); @@ -160,15 +268,20 @@ public class TestJoinUtil extends LuceneTestCase { String randomValue = context.randomUniqueValues[r]; FixedBitSet expectedResult = createExpectedResult(randomValue, from, indexSearcher.getIndexReader(), context); - Query actualQuery = new TermQuery(new Term("value", randomValue)); + final Query actualQuery = new TermQuery(new Term("value", randomValue)); if (VERBOSE) { System.out.println("actualQuery=" + actualQuery); } - Query joinQuery; + final ScoreMode scoreMode = ScoreMode.values()[random().nextInt(ScoreMode.values().length)]; + if (VERBOSE) { + System.out.println("scoreMode=" + scoreMode); + } + + final Query joinQuery; if (from) { - joinQuery = JoinUtil.createJoinQuery("from", multipleValuesPerDocument, "to", actualQuery, indexSearcher); + joinQuery = JoinUtil.createJoinQuery("from", 
multipleValuesPerDocument, "to", actualQuery, indexSearcher, scoreMode); } else { - joinQuery = JoinUtil.createJoinQuery("to", multipleValuesPerDocument, "from", actualQuery, indexSearcher); + joinQuery = JoinUtil.createJoinQuery("to", multipleValuesPerDocument, "from", actualQuery, indexSearcher, scoreMode); } if (VERBOSE) { System.out.println("joinQuery=" + joinQuery); @@ -176,26 +289,30 @@ public class TestJoinUtil extends LuceneTestCase { // Need to know all documents that have matches. TopDocs doesn't give me that and then I'd be also testing TopDocsCollector... final FixedBitSet actualResult = new FixedBitSet(indexSearcher.getIndexReader().maxDoc()); + final TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(10, false); indexSearcher.search(joinQuery, new Collector() { int docBase; public void collect(int doc) throws IOException { actualResult.set(doc + docBase); + topScoreDocCollector.collect(doc); } public void setNextReader(AtomicReaderContext context) throws IOException { docBase = context.docBase; + topScoreDocCollector.setNextReader(context); } public void setScorer(Scorer scorer) throws IOException { + topScoreDocCollector.setScorer(scorer); } public boolean acceptsDocsOutOfOrder() { - return true; + return topScoreDocCollector.acceptsDocsOutOfOrder(); } }); - + // Asserting bit set... if (VERBOSE) { System.out.println("expected cardinality:" + expectedResult.cardinality()); DocIdSetIterator iterator = expectedResult.iterator(); @@ -208,8 +325,28 @@ public class TestJoinUtil extends LuceneTestCase { System.out.println(String.format("Actual doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id"))); } } - assertEquals(expectedResult, actualResult); + + // Asserting TopDocs... 
+ TopDocs expectedTopDocs = createExpectedTopDocs(randomValue, from, scoreMode, context); + TopDocs actualTopDocs = topScoreDocCollector.topDocs(); + assertEquals(expectedTopDocs.totalHits, actualTopDocs.totalHits); + assertEquals(expectedTopDocs.scoreDocs.length, actualTopDocs.scoreDocs.length); + if (scoreMode == ScoreMode.None) { + continue; + } + + assertEquals(expectedTopDocs.getMaxScore(), actualTopDocs.getMaxScore(), 0.0f); + for (int i = 0; i < expectedTopDocs.scoreDocs.length; i++) { + if (VERBOSE) { + System.out.printf("Expected doc: %d | Actual doc: %d\n", expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc); + System.out.printf("Expected score: %f | Actual score: %f\n", expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score); + } + assertEquals(expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc); + assertEquals(expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score, 0.0f); + Explanation explanation = indexSearcher.explain(joinQuery, expectedTopDocs.scoreDocs[i].doc); + assertEquals(expectedTopDocs.scoreDocs[i].score, explanation.getValue(), 0.0f); + } } topLevelReader.close(); dir.close(); @@ -238,20 +375,21 @@ public class TestJoinUtil extends LuceneTestCase { context.randomUniqueValues[i] = uniqueRandomValue; } + RandomDoc[] docs = new RandomDoc[nDocs]; for (int i = 0; i < nDocs; i++) { String id = Integer.toString(i); int randomI = random().nextInt(context.randomUniqueValues.length); String value = context.randomUniqueValues[randomI]; Document document = new Document(); - document.add(newField(random(), "id", id, TextField.TYPE_STORED)); - document.add(newField(random(), "value", value, TextField.TYPE_STORED)); + document.add(newField(random(), "id", id, TextField.TYPE_UNSTORED)); + document.add(newField(random(), "value", value, TextField.TYPE_UNSTORED)); boolean from = context.randomFrom[randomI]; int numberOfLinkValues = multipleValuesPerDocument ? 
2 + random().nextInt(10) : 1; - RandomDoc doc = new RandomDoc(id, numberOfLinkValues, value); + docs[i] = new RandomDoc(id, numberOfLinkValues, value, from); for (int j = 0; j < numberOfLinkValues; j++) { String linkValue = context.randomUniqueValues[random().nextInt(context.randomUniqueValues.length)]; - doc.linkValues.add(linkValue); + docs[i].linkValues.add(linkValue); if (from) { if (!context.fromDocuments.containsKey(linkValue)) { context.fromDocuments.put(linkValue, new ArrayList()); @@ -260,9 +398,9 @@ public class TestJoinUtil extends LuceneTestCase { context.randomValueFromDocs.put(value, new ArrayList()); } - context.fromDocuments.get(linkValue).add(doc); - context.randomValueFromDocs.get(value).add(doc); - document.add(newField(random(), "from", linkValue, TextField.TYPE_STORED)); + context.fromDocuments.get(linkValue).add(docs[i]); + context.randomValueFromDocs.get(value).add(docs[i]); + document.add(newField(random(), "from", linkValue, TextField.TYPE_UNSTORED)); } else { if (!context.toDocuments.containsKey(linkValue)) { context.toDocuments.put(linkValue, new ArrayList()); @@ -271,9 +409,9 @@ public class TestJoinUtil extends LuceneTestCase { context.randomValueToDocs.put(value, new ArrayList()); } - context.toDocuments.get(linkValue).add(doc); - context.randomValueToDocs.get(value).add(doc); - document.add(newField(random(), "to", linkValue, TextField.TYPE_STORED)); + context.toDocuments.get(linkValue).add(docs[i]); + context.randomValueToDocs.get(value).add(docs[i]); + document.add(newField(random(), "to", linkValue, TextField.TYPE_UNSTORED)); } } @@ -289,12 +427,235 @@ public class TestJoinUtil extends LuceneTestCase { w.commit(); } if (VERBOSE) { - System.out.println("Added document[" + i + "]: " + document); + System.out.println("Added document[" + docs[i].id + "]: " + document); } } + + // Pre-compute all possible hits for all unique random values. On top of this also compute all possible score for + // any ScoreMode. 
+ IndexSearcher fromSearcher = newSearcher(fromWriter.getReader()); + IndexSearcher toSearcher = newSearcher(toWriter.getReader()); + for (int i = 0; i < context.randomUniqueValues.length; i++) { + String uniqueRandomValue = context.randomUniqueValues[i]; + final String fromField; + final String toField; + final Map> queryVals; + if (context.randomFrom[i]) { + fromField = "from"; + toField = "to"; + queryVals = context.fromHitsToJoinScore; + } else { + fromField = "to"; + toField = "from"; + queryVals = context.toHitsToJoinScore; + } + final Map joinValueToJoinScores = new HashMap(); + if (multipleValuesPerDocument) { + fromSearcher.search(new TermQuery(new Term("value", uniqueRandomValue)), new Collector() { + + private Scorer scorer; + private DocTermOrds docTermOrds; + private TermsEnum docTermsEnum; + private DocTermOrds.TermOrdsIterator reuse; + + public void collect(int doc) throws IOException { + if (docTermOrds.isEmpty()) { + return; + } + + reuse = docTermOrds.lookup(doc, reuse); + int[] buffer = new int[5]; + + int chunk; + do { + chunk = reuse.read(buffer); + if (chunk == 0) { + return; + } + + for (int idx = 0; idx < chunk; idx++) { + int key = buffer[idx]; + docTermsEnum.seekExact((long) key); + BytesRef joinValue = docTermsEnum.term(); + if (joinValue == null) { + continue; + } + + JoinScore joinScore = joinValueToJoinScores.get(joinValue); + if (joinScore == null) { + joinValueToJoinScores.put(BytesRef.deepCopyOf(joinValue), joinScore = new JoinScore()); + } + joinScore.addScore(scorer.score()); + } + } while (chunk >= buffer.length); + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + docTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), fromField); + docTermsEnum = docTermOrds.getOrdTermsEnum(context.reader()); + reuse = null; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + public boolean acceptsDocsOutOfOrder() { + return false; + } + }); + } else { + 
fromSearcher.search(new TermQuery(new Term("value", uniqueRandomValue)), new Collector() { + + private Scorer scorer; + private FieldCache.DocTerms terms; + private final BytesRef spare = new BytesRef(); + + public void collect(int doc) throws IOException { + BytesRef joinValue = terms.getTerm(doc, spare); + if (joinValue == null) { + return; + } + + JoinScore joinScore = joinValueToJoinScores.get(joinValue); + if (joinScore == null) { + joinValueToJoinScores.put(BytesRef.deepCopyOf(joinValue), joinScore = new JoinScore()); + } + joinScore.addScore(scorer.score()); + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + terms = FieldCache.DEFAULT.getTerms(context.reader(), fromField); + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + public boolean acceptsDocsOutOfOrder() { + return false; + } + }); + } + + final Map docToJoinScore = new HashMap(); + if (multipleValuesPerDocument) { + toSearcher.search(new MatchAllDocsQuery(), new Collector() { + + private DocTermOrds docTermOrds; + private TermsEnum docTermsEnum; + private DocTermOrds.TermOrdsIterator reuse; + private int docBase; + + public void collect(int doc) throws IOException { + if (docTermOrds.isEmpty()) { + return; + } + + reuse = docTermOrds.lookup(doc, reuse); + int[] buffer = new int[5]; + + int chunk; + do { + chunk = reuse.read(buffer); + if (chunk == 0) { + return; + } + + for (int idx = 0; idx < chunk; idx++) { + int key = buffer[idx]; + docTermsEnum.seekExact((long) key); + JoinScore joinScore = joinValueToJoinScores.get(docTermsEnum.term()); + if (joinScore == null) { + continue; + } + Integer basedDoc = docBase + doc; + // First encountered join value determines the score. + // Something to keep in mind for many-to-many relations. 
+ if (!docToJoinScore.containsKey(basedDoc)) { + docToJoinScore.put(basedDoc, joinScore); + } + } + } while (chunk >= buffer.length); + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + docBase = context.docBase; + docTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), toField); + docTermsEnum = docTermOrds.getOrdTermsEnum(context.reader()); + reuse = null; + } + + public boolean acceptsDocsOutOfOrder() {return false;} + public void setScorer(Scorer scorer) throws IOException {} + }); + } else { + toSearcher.search(new MatchAllDocsQuery(), new Collector() { + + private FieldCache.DocTerms terms; + private int docBase; + private final BytesRef spare = new BytesRef(); + + public void collect(int doc) throws IOException { + JoinScore joinScore = joinValueToJoinScores.get(terms.getTerm(doc, spare)); + if (joinScore == null) { + return; + } + docToJoinScore.put(docBase + doc, joinScore); + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + terms = FieldCache.DEFAULT.getTerms(context.reader(), toField); + docBase = context.docBase; + } + + public boolean acceptsDocsOutOfOrder() {return false;} + public void setScorer(Scorer scorer) throws IOException {} + }); + } + queryVals.put(uniqueRandomValue, docToJoinScore); + } + + fromSearcher.getIndexReader().close(); + toSearcher.getIndexReader().close(); + return context; } + private TopDocs createExpectedTopDocs(String queryValue, + final boolean from, + final ScoreMode scoreMode, + IndexIterationContext context) throws IOException { + + Map hitsToJoinScores; + if (from) { + hitsToJoinScores = context.fromHitsToJoinScore.get(queryValue); + } else { + hitsToJoinScores = context.toHitsToJoinScore.get(queryValue); + } + List> hits = new ArrayList>(hitsToJoinScores.entrySet()); + Collections.sort(hits, new Comparator>() { + + public int compare(Map.Entry hit1, Map.Entry hit2) { + float score1 = hit1.getValue().score(scoreMode); + float score2 = 
hit2.getValue().score(scoreMode); + + int cmp = Float.compare(score2, score1); + if (cmp != 0) { + return cmp; + } + return hit1.getKey() - hit2.getKey(); + } + + }); + ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(10, hits.size())]; + for (int i = 0; i < scoreDocs.length; i++) { + Map.Entry hit = hits.get(i); + scoreDocs[i] = new ScoreDoc(hit.getKey(), hit.getValue().score(scoreMode)); + } + return new TopDocs(hits.size(), scoreDocs, hits.isEmpty() ? Float.NaN : hits.get(0).getValue().score(scoreMode)); + } + private FixedBitSet createExpectedResult(String queryValue, boolean from, IndexReader topLevelReader, IndexIterationContext context) throws IOException { final Map> randomValueDocs; final Map> linkValueDocuments; @@ -339,6 +700,9 @@ public class TestJoinUtil extends LuceneTestCase { Map> randomValueFromDocs = new HashMap>(); Map> randomValueToDocs = new HashMap>(); + Map> fromHitsToJoinScore = new HashMap>(); + Map> toHitsToJoinScore = new HashMap>(); + } private static class RandomDoc { @@ -346,12 +710,44 @@ public class TestJoinUtil extends LuceneTestCase { final String id; final List linkValues; final String value; + final boolean from; - private RandomDoc(String id, int numberOfLinkValues, String value) { + private RandomDoc(String id, int numberOfLinkValues, String value, boolean from) { this.id = id; + this.from = from; linkValues = new ArrayList(numberOfLinkValues); this.value = value; } } + private static class JoinScore { + + float maxScore; + float total; + int count; + + void addScore(float score) { + total += score; + if (score > maxScore) { + maxScore = score; + } + count++; + } + + float score(ScoreMode mode) { + switch (mode) { + case None: + return 1.0f; + case Total: + return total; + case Avg: + return total / count; + case Max: + return maxScore; + } + throw new IllegalArgumentException("Unsupported ScoreMode: " + mode); + } + + } + }