LUCENE-4043: Added scoring support via score mode for query time joining.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1343966 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Martijn van Groningen 2012-05-29 20:37:31 +00:00
parent 5d3dba2c56
commit 36acada762
10 changed files with 1111 additions and 79 deletions

View File

@ -880,6 +880,9 @@ New features
returning results after a specified FieldDoc for deep
paging. (Mike McCandless)
* LUCENE-4043: Added scoring support via score mode for query time joining.
(Martijn van Groningen, Mike McCandless)
Optimizations
* LUCENE-2588: Don't store unnecessary suffixes when writing the terms

View File

@ -220,6 +220,13 @@ public class DocTermOrds {
return numTermsInField;
}
/**
* @return Whether this <code>DocTermOrds</code> instance is empty.
*/
public boolean isEmpty() {
return index == null;
}
/** Subclass can override this */
protected void visitTerm(TermsEnum te, int termNum) throws IOException {
}

View File

@ -38,12 +38,24 @@ public final class JoinUtil {
* <p/>
* Execute the returned query with a {@link IndexSearcher} to retrieve all documents that have the same terms in the
* to field that match with documents matching the specified fromQuery and have the same terms in the from field.
* <p/>
* In the case a single document relates to more than one document the <code>multipleValuesPerDocument</code> option
* should be set to true. When the <code>multipleValuesPerDocument</code> is set to <code>true</code> only the
* the score from the first encountered join value originating from the 'from' side is mapped into the 'to' side.
* Even in the case when a second join value related to a specific document yields a higher score. Obviously this
* doesn't apply in the case that {@link ScoreMode#None} is used, since no scores are computed at all.
* </p>
* Memory considerations: During joining all unique join values are kept in memory. On top of that when the scoreMode
* isn't set to {@link ScoreMode#None} a float value per unique join value is kept in memory for computing scores.
* When scoreMode is set to {@link ScoreMode#Avg} also an additional integer value is kept in memory per unique
* join value.
*
* @param fromField The from field to join from
* @param multipleValuesPerDocument Whether the from field has multiple terms per document
* @param toField The to field to join to
* @param fromQuery The query to match documents on the from side
* @param fromSearcher The searcher that executed the specified fromQuery
* @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
* @return a {@link Query} instance that can be used to join documents based on the
* terms in the from and to field
* @throws IOException If I/O related errors occur
@ -52,10 +64,29 @@ public final class JoinUtil {
boolean multipleValuesPerDocument,
String toField,
Query fromQuery,
IndexSearcher fromSearcher) throws IOException {
IndexSearcher fromSearcher,
ScoreMode scoreMode) throws IOException {
switch (scoreMode) {
case None:
TermsCollector termsCollector = TermsCollector.create(fromField, multipleValuesPerDocument);
fromSearcher.search(fromQuery, termsCollector);
return new TermsQuery(toField, termsCollector.getCollectorTerms());
case Total:
case Max:
case Avg:
TermsWithScoreCollector termsWithScoreCollector =
TermsWithScoreCollector.create(fromField, multipleValuesPerDocument, scoreMode);
fromSearcher.search(fromQuery, termsWithScoreCollector);
return new TermsIncludingScoreQuery(
toField,
multipleValuesPerDocument,
termsWithScoreCollector.getCollectedTerms(),
termsWithScoreCollector.getScoresPerTerm(),
fromQuery
);
default:
throw new IllegalArgumentException(String.format("Score mode %s isn't supported.", scoreMode));
}
}
}

View File

@ -0,0 +1,45 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* How to aggregate multiple child hit scores into a single parent score.
*/
public enum ScoreMode {
/**
* Do no scoring.
*/
None,
/**
* Parent hit's score is the average of all child scores.
*/
Avg,
/**
* Parent hit's score is the max of all child scores.
*/
Max,
/**
* Parent hit's score is the sum of all child scores.
*/
Total
}

View File

@ -0,0 +1,271 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.ComplexExplanation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.FixedBitSet;
import java.io.IOException;
import java.util.Set;
class TermsIncludingScoreQuery extends Query {
final String field;
final boolean multipleValuesPerDocument;
final BytesRefHash terms;
final float[] scores;
final int[] ords;
final Query originalQuery;
final Query unwrittenOriginalQuery;
TermsIncludingScoreQuery(String field, boolean multipleValuesPerDocument, BytesRefHash terms, float[] scores, Query originalQuery) {
this.field = field;
this.multipleValuesPerDocument = multipleValuesPerDocument;
this.terms = terms;
this.scores = scores;
this.originalQuery = originalQuery;
this.ords = terms.sort(BytesRef.getUTF8SortedAsUnicodeComparator());
this.unwrittenOriginalQuery = originalQuery;
}
private TermsIncludingScoreQuery(String field, boolean multipleValuesPerDocument, BytesRefHash terms, float[] scores, int[] ords, Query originalQuery, Query unwrittenOriginalQuery) {
this.field = field;
this.multipleValuesPerDocument = multipleValuesPerDocument;
this.terms = terms;
this.scores = scores;
this.originalQuery = originalQuery;
this.ords = ords;
this.unwrittenOriginalQuery = unwrittenOriginalQuery;
}
public String toString(String string) {
return String.format("TermsIncludingScoreQuery{field=%s;originalQuery=%s}", field, unwrittenOriginalQuery);
}
@Override
public void extractTerms(Set<Term> terms) {
originalQuery.extractTerms(terms);
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
final Query originalQueryRewrite = originalQuery.rewrite(reader);
if (originalQueryRewrite != originalQuery) {
Query rewritten = new TermsIncludingScoreQuery(field, multipleValuesPerDocument, terms, scores,
ords, originalQueryRewrite, originalQuery);
rewritten.setBoost(getBoost());
return rewritten;
} else {
return this;
}
}
@Override
public Weight createWeight(IndexSearcher searcher) throws IOException {
final Weight originalWeight = originalQuery.createWeight(searcher);
return new Weight() {
private TermsEnum segmentTermsEnum;
public Explanation explain(AtomicReaderContext context, int doc) throws IOException {
SVInnerScorer scorer = (SVInnerScorer) scorer(context, true, false, context.reader().getLiveDocs());
if (scorer != null) {
if (scorer.advance(doc) == doc) {
return scorer.explain();
}
}
return new ComplexExplanation(false, 0.0f, "Not a match");
}
public Query getQuery() {
return TermsIncludingScoreQuery.this;
}
public float getValueForNormalization() throws IOException {
return originalWeight.getValueForNormalization() * TermsIncludingScoreQuery.this.getBoost() * TermsIncludingScoreQuery.this.getBoost();
}
public void normalize(float norm, float topLevelBoost) {
originalWeight.normalize(norm, topLevelBoost * TermsIncludingScoreQuery.this.getBoost());
}
public Scorer scorer(AtomicReaderContext context, boolean scoreDocsInOrder, boolean topScorer, Bits acceptDocs) throws IOException {
Terms terms = context.reader().terms(field);
if (terms == null) {
return null;
}
segmentTermsEnum = terms.iterator(segmentTermsEnum);
if (multipleValuesPerDocument) {
return new MVInnerScorer(this, acceptDocs, segmentTermsEnum, context.reader().maxDoc());
} else {
return new SVInnerScorer(this, acceptDocs, segmentTermsEnum);
}
}
};
}
// This impl assumes that the 'join' values are used uniquely per doc per field. Used for one to many relations.
class SVInnerScorer extends Scorer {
final BytesRef spare = new BytesRef();
final Bits acceptDocs;
final TermsEnum termsEnum;
int upto;
DocsEnum docsEnum;
DocsEnum reuse;
int scoreUpto;
SVInnerScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum) {
super(weight);
this.acceptDocs = acceptDocs;
this.termsEnum = termsEnum;
}
public float score() throws IOException {
return scores[ords[scoreUpto]];
}
public Explanation explain() throws IOException {
return new ComplexExplanation(true, score(), "Score based on join value " + termsEnum.term().utf8ToString());
}
public int docID() {
return docsEnum != null ? docsEnum.docID() : DocIdSetIterator.NO_MORE_DOCS;
}
public int nextDoc() throws IOException {
if (docsEnum != null) {
int docId = docsEnum.nextDoc();
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
docsEnum = null;
} else {
return docId;
}
}
do {
if (upto == terms.size()) {
return DocIdSetIterator.NO_MORE_DOCS;
}
scoreUpto = upto;
TermsEnum.SeekStatus status = termsEnum.seekCeil(terms.get(ords[upto++], spare), true);
if (status == TermsEnum.SeekStatus.FOUND) {
docsEnum = reuse = termsEnum.docs(acceptDocs, reuse, false);
}
} while (docsEnum == null);
return docsEnum.nextDoc();
}
public int advance(int target) throws IOException {
int docId;
do {
docId = nextDoc();
if (docId < target) {
int tempDocId = docsEnum.advance(target);
if (tempDocId == target) {
docId = tempDocId;
break;
}
} else if (docId == target) {
break;
}
docsEnum = null; // goto the next ord.
} while (docId != DocIdSetIterator.NO_MORE_DOCS);
return docId;
}
}
// This impl that tracks whether a docid has already been emitted. This check makes sure that docs aren't emitted
// twice for different join values. This means that the first encountered join value determines the score of a document
// even if other join values yield a higher score.
class MVInnerScorer extends SVInnerScorer {
final FixedBitSet alreadyEmittedDocs;
MVInnerScorer(Weight weight, Bits acceptDocs, TermsEnum termsEnum, int maxDoc) {
super(weight, acceptDocs, termsEnum);
alreadyEmittedDocs = new FixedBitSet(maxDoc);
}
public int nextDoc() throws IOException {
if (docsEnum != null) {
int docId;
do {
docId = docsEnum.nextDoc();
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
} while (alreadyEmittedDocs.get(docId));
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
docsEnum = null;
} else {
alreadyEmittedDocs.set(docId);
return docId;
}
}
for (;;) {
do {
if (upto == terms.size()) {
return DocIdSetIterator.NO_MORE_DOCS;
}
scoreUpto = upto;
TermsEnum.SeekStatus status = termsEnum.seekCeil(terms.get(ords[upto++], spare), true);
if (status == TermsEnum.SeekStatus.FOUND) {
docsEnum = reuse = termsEnum.docs(acceptDocs, reuse, false);
}
} while (docsEnum == null);
int docId;
do {
docId = docsEnum.nextDoc();
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
} while (alreadyEmittedDocs.get(docId));
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
docsEnum = null;
} else {
alreadyEmittedDocs.set(docId);
return docId;
}
}
}
}
}

View File

@ -0,0 +1,292 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocTermOrds;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FieldCache;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import java.io.IOException;
abstract class TermsWithScoreCollector extends Collector {
private final static int INITIAL_ARRAY_SIZE = 256;
final String field;
final BytesRefHash collectedTerms = new BytesRefHash();
final ScoreMode scoreMode;
Scorer scorer;
float[] scoreSums = new float[INITIAL_ARRAY_SIZE];
TermsWithScoreCollector(String field, ScoreMode scoreMode) {
this.field = field;
this.scoreMode = scoreMode;
}
public BytesRefHash getCollectedTerms() {
return collectedTerms;
}
public float[] getScoresPerTerm() {
return scoreSums;
}
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
}
public boolean acceptsDocsOutOfOrder() {
return true;
}
/**
* Chooses the right {@link TermsWithScoreCollector} implementation.
*
* @param field The field to collect terms for
* @param multipleValuesPerDocument Whether the field to collect terms for has multiple values per document.
* @return a {@link TermsWithScoreCollector} instance
*/
static TermsWithScoreCollector create(String field, boolean multipleValuesPerDocument, ScoreMode scoreMode) {
if (multipleValuesPerDocument) {
switch (scoreMode) {
case Avg:
return new MV.Avg(field);
default:
return new MV(field, scoreMode);
}
} else {
switch (scoreMode) {
case Avg:
return new SV.Avg(field);
default:
return new SV(field, scoreMode);
}
}
}
// impl that works with single value per document
static class SV extends TermsWithScoreCollector {
final BytesRef spare = new BytesRef();
FieldCache.DocTerms fromDocTerms;
SV(String field, ScoreMode scoreMode) {
super(field, scoreMode);
}
public void collect(int doc) throws IOException {
int ord = collectedTerms.add(fromDocTerms.getTerm(doc, spare));
if (ord < 0) {
ord = -ord - 1;
} else {
if (ord >= scoreSums.length) {
scoreSums = ArrayUtil.grow(scoreSums);
}
}
float current = scorer.score();
float existing = scoreSums[ord];
if (Float.compare(existing, 0.0f) == 0) {
scoreSums[ord] = current;
} else {
switch (scoreMode) {
case Total:
scoreSums[ord] = scoreSums[ord] + current;
break;
case Max:
if (current > existing) {
scoreSums[ord] = current;
}
}
}
}
public void setNextReader(AtomicReaderContext context) throws IOException {
fromDocTerms = FieldCache.DEFAULT.getTerms(context.reader(), field);
}
static class Avg extends SV {
int[] scoreCounts = new int[INITIAL_ARRAY_SIZE];
Avg(String field) {
super(field, ScoreMode.Avg);
}
@Override
public void collect(int doc) throws IOException {
int ord = collectedTerms.add(fromDocTerms.getTerm(doc, spare));
if (ord < 0) {
ord = -ord - 1;
} else {
if (ord >= scoreSums.length) {
scoreSums = ArrayUtil.grow(scoreSums);
scoreCounts = ArrayUtil.grow(scoreCounts);
}
}
float current = scorer.score();
float existing = scoreSums[ord];
if (Float.compare(existing, 0.0f) == 0) {
scoreSums[ord] = current;
scoreCounts[ord] = 1;
} else {
scoreSums[ord] = scoreSums[ord] + current;
scoreCounts[ord]++;
}
}
@Override
public float[] getScoresPerTerm() {
if (scoreCounts != null) {
for (int i = 0; i < scoreCounts.length; i++) {
scoreSums[i] = scoreSums[i] / scoreCounts[i];
}
scoreCounts = null;
}
return scoreSums;
}
}
}
// impl that works with multiple values per document
static class MV extends TermsWithScoreCollector {
DocTermOrds fromDocTermOrds;
TermsEnum docTermsEnum;
DocTermOrds.TermOrdsIterator reuse;
MV(String field, ScoreMode scoreMode) {
super(field, scoreMode);
}
public void collect(int doc) throws IOException {
reuse = fromDocTermOrds.lookup(doc, reuse);
int[] buffer = new int[5];
int chunk;
do {
chunk = reuse.read(buffer);
if (chunk == 0) {
return;
}
for (int idx = 0; idx < chunk; idx++) {
int key = buffer[idx];
docTermsEnum.seekExact((long) key);
int ord = collectedTerms.add(docTermsEnum.term());
if (ord < 0) {
ord = -ord - 1;
} else {
if (ord >= scoreSums.length) {
scoreSums = ArrayUtil.grow(scoreSums);
}
}
final float current = scorer.score();
final float existing = scoreSums[ord];
if (Float.compare(existing, 0.0f) == 0) {
scoreSums[ord] = current;
} else {
switch (scoreMode) {
case Total:
scoreSums[ord] = existing + current;
break;
case Max:
if (current > existing) {
scoreSums[ord] = current;
}
}
}
}
} while (chunk >= buffer.length);
}
public void setNextReader(AtomicReaderContext context) throws IOException {
fromDocTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), field);
docTermsEnum = fromDocTermOrds.getOrdTermsEnum(context.reader());
reuse = null; // LUCENE-3377 needs to be fixed first then this statement can be removed...
}
static class Avg extends MV {
int[] scoreCounts = new int[INITIAL_ARRAY_SIZE];
Avg(String field) {
super(field, ScoreMode.Avg);
}
@Override
public void collect(int doc) throws IOException {
reuse = fromDocTermOrds.lookup(doc, reuse);
int[] buffer = new int[5];
int chunk;
do {
chunk = reuse.read(buffer);
if (chunk == 0) {
return;
}
for (int idx = 0; idx < chunk; idx++) {
int key = buffer[idx];
docTermsEnum.seekExact((long) key);
int ord = collectedTerms.add(docTermsEnum.term());
if (ord < 0) {
ord = -ord - 1;
} else {
if (ord >= scoreSums.length) {
scoreSums = ArrayUtil.grow(scoreSums);
scoreCounts = ArrayUtil.grow(scoreCounts);
}
}
float current = scorer.score();
float existing = scoreSums[ord];
if (Float.compare(existing, 0.0f) == 0) {
scoreSums[ord] = current;
scoreCounts[ord] = 1;
} else {
scoreSums[ord] = scoreSums[ord] + current;
scoreCounts[ord]++;
}
}
} while (chunk >= buffer.length);
}
@Override
public float[] getScoresPerTerm() {
if (scoreCounts != null) {
for (int i = 0; i < scoreCounts.length; i++) {
scoreSums[i] = scoreSums[i] / scoreCounts[i];
}
scoreCounts = null;
}
return scoreSums;
}
}
}
}

View File

@ -33,7 +33,6 @@ import org.apache.lucene.search.Filter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Scorer.ChildScorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.ArrayUtil;
@ -82,24 +81,8 @@ import org.apache.lucene.util.FixedBitSet;
*
* @lucene.experimental
*/
public class ToParentBlockJoinQuery extends Query {
/** How to aggregate multiple child hit scores into a
* single parent score. */
public static enum ScoreMode {
/** Do no scoring. */
None,
/** Parent hit's score is the average of all child
scores. */
Avg,
/** Parent hit's score is the max of all child
scores. */
Max,
/** Parent hit's score is the sum of all child
scores. */
Total};
private final Filter parentsFilter;
private final Query childQuery;

View File

@ -56,7 +56,7 @@
any query matching parent documents, creating the joined query
matching only child documents.
<h2>Search-time joins</h2>
<h2>Query-time joins</h2>
<p>
The query time joining is index term based and implemented as two pass search. The first pass collects all the terms from a fromField
@ -68,22 +68,26 @@
<li><code>fromField</code>: The from field to join from.
<li><code>fromQuery</code>: The query executed to collect the from terms. This is usually the user specified query.
<li><code>multipleValuesPerDocument</code>: Whether the fromField contains more than one value per document
<li><code>scoreMode</code>: Defines how scores are translated to the other join side. If you don't care about scoring
use {@link org.apache.lucene.search.join.ScoreMode#None} mode. This will disable scoring and is therefore more
efficient (requires less memory and is faster).
<li><code>toField</code>: The to field to join to
</ul>
<p>
Basically the query-time joining is accessible from one static method. The user of this method supplies the method
with the described input and a <code>IndexSearcher</code> where the from terms need to be collected from. The returned
query can be executed with the same <code>IndexSearcher</code>, but also with another <code>IndexSearcher</code>.
Example usage of the {@link org.apache.lucene.search.join.JoinUtil#createJoinQuery(String, boolean, String, org.apache.lucene.search.Query, org.apache.lucene.search.IndexSearcher)
Example usage of the {@link org.apache.lucene.search.join.JoinUtil#createJoinQuery(String, boolean, String, org.apache.lucene.search.Query, org.apache.lucene.search.IndexSearcher, org.apache.lucene.search.join.ScoreMode)
JoinUtil.createJoinQuery()} :
</p>
<pre class="prettyprint">
String fromField = "from"; // Name of the from field
boolean multipleValuesPerDocument = false; // Set only yo true in the case when your fromField has multiple values per document in your index
String toField = "to"; // Name of the to field
ScoreMode scoreMode = ScoreMode.Max // Defines how the scores are translated into the other side of the join.
Query fromQuery = new TermQuery(new Term("content", searchTerm)); // Query executed to collect from values to join to the to values
Query joinQuery = JoinUtil.createJoinQuery(fromField, multipleValuesPerDocument, toField, fromQuery, fromSearcher);
Query joinQuery = JoinUtil.createJoinQuery(fromField, multipleValuesPerDocument, toField, fromQuery, fromSearcher, scoreMode);
TopDocs topDocs = toSearcher.search(joinQuery, 10); // Note: toSearcher can be the same as the fromSearcher
// Render topDocs...
</pre>

View File

@ -96,7 +96,7 @@ public class TestBlockJoin extends LuceneTestCase {
// Wrap the child document query to 'join' any matches
// up to corresponding parent:
ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ScoreMode.Avg);
// Combine the parent and nested child queries into a single query for a candidate
BooleanQuery fullQuery = new BooleanQuery();
@ -198,7 +198,7 @@ public class TestBlockJoin extends LuceneTestCase {
// Wrap the child document query to 'join' any matches
// up to corresponding parent:
ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, ScoreMode.Avg);
assertEquals("no filter - both passed", 2, s.search(childJoinQuery, 10).totalHits);
@ -259,7 +259,7 @@ public class TestBlockJoin extends LuceneTestCase {
w.close();
IndexSearcher s = newSearcher(r);
ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(new MatchAllDocsQuery(), new QueryWrapperFilter(new MatchAllDocsQuery()), ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(new MatchAllDocsQuery(), new QueryWrapperFilter(new MatchAllDocsQuery()), ScoreMode.Avg);
s.search(q, 10);
BooleanQuery bq = new BooleanQuery();
bq.setBoost(2f); // we boost the BQ
@ -493,15 +493,15 @@ public class TestBlockJoin extends LuceneTestCase {
}
final int x = random().nextInt(4);
final ToParentBlockJoinQuery.ScoreMode agg;
final ScoreMode agg;
if (x == 0) {
agg = ToParentBlockJoinQuery.ScoreMode.None;
agg = ScoreMode.None;
} else if (x == 1) {
agg = ToParentBlockJoinQuery.ScoreMode.Max;
agg = ScoreMode.Max;
} else if (x == 2) {
agg = ToParentBlockJoinQuery.ScoreMode.Total;
agg = ScoreMode.Total;
} else {
agg = ToParentBlockJoinQuery.ScoreMode.Avg;
agg = ScoreMode.Avg;
}
final ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, agg);
@ -584,7 +584,7 @@ public class TestBlockJoin extends LuceneTestCase {
final boolean trackScores;
final boolean trackMaxScore;
if (agg == ToParentBlockJoinQuery.ScoreMode.None) {
if (agg == ScoreMode.None) {
trackScores = false;
trackMaxScore = false;
} else {
@ -881,8 +881,8 @@ public class TestBlockJoin extends LuceneTestCase {
// Wrap the child document query to 'join' any matches
// up to corresponding parent:
ToParentBlockJoinQuery childJobJoinQuery = new ToParentBlockJoinQuery(childJobQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery childQualificationJoinQuery = new ToParentBlockJoinQuery(childQualificationQuery, parentsFilter, ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery childJobJoinQuery = new ToParentBlockJoinQuery(childJobQuery, parentsFilter, ScoreMode.Avg);
ToParentBlockJoinQuery childQualificationJoinQuery = new ToParentBlockJoinQuery(childQualificationQuery, parentsFilter, ScoreMode.Avg);
// Combine the parent and nested child queries into a single query for a candidate
BooleanQuery fullQuery = new BooleanQuery();
@ -952,7 +952,7 @@ public class TestBlockJoin extends LuceneTestCase {
new QueryWrapperFilter(
new TermQuery(new Term("parent", "1"))));
ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ScoreMode.Avg);
Weight weight = s.createNormalizedWeight(q);
DocIdSetIterator disi = weight.scorer(s.getIndexReader().getTopReaderContext().leaves()[0], true, true, null);
assertEquals(1, disi.advance(1));
@ -986,7 +986,7 @@ public class TestBlockJoin extends LuceneTestCase {
new QueryWrapperFilter(
new TermQuery(new Term("isparent", "yes"))));
ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ToParentBlockJoinQuery.ScoreMode.Avg);
ToParentBlockJoinQuery q = new ToParentBlockJoinQuery(tq, parentFilter, ScoreMode.Avg);
Weight weight = s.createNormalizedWeight(q);
DocIdSetIterator disi = weight.scorer(s.getIndexReader().getTopReaderContext().leaves()[0], true, true, null);
assertEquals(2, disi.advance(0));

View File

@ -22,8 +22,26 @@ import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocTermOrds;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FieldCache;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
@ -49,45 +67,45 @@ public class TestJoinUtil extends LuceneTestCase {
// 0
Document doc = new Document();
doc.add(new Field("description", "random text", TextField.TYPE_STORED));
doc.add(new Field("name", "name1", TextField.TYPE_STORED));
doc.add(new Field(idField, "1", TextField.TYPE_STORED));
doc.add(new Field("description", "random text", TextField.TYPE_UNSTORED));
doc.add(new Field("name", "name1", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "1", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 1
doc = new Document();
doc.add(new Field("price", "10.0", TextField.TYPE_STORED));
doc.add(new Field(idField, "2", TextField.TYPE_STORED));
doc.add(new Field(toField, "1", TextField.TYPE_STORED));
doc.add(new Field("price", "10.0", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "2", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 2
doc = new Document();
doc.add(new Field("price", "20.0", TextField.TYPE_STORED));
doc.add(new Field(idField, "3", TextField.TYPE_STORED));
doc.add(new Field(toField, "1", TextField.TYPE_STORED));
doc.add(new Field("price", "20.0", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "3", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 3
doc = new Document();
doc.add(new Field("description", "more random text", TextField.TYPE_STORED));
doc.add(new Field("name", "name2", TextField.TYPE_STORED));
doc.add(new Field(idField, "4", TextField.TYPE_STORED));
doc.add(new Field("description", "more random text", TextField.TYPE_UNSTORED));
doc.add(new Field("name", "name2", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "4", TextField.TYPE_UNSTORED));
w.addDocument(doc);
w.commit();
// 4
doc = new Document();
doc.add(new Field("price", "10.0", TextField.TYPE_STORED));
doc.add(new Field(idField, "5", TextField.TYPE_STORED));
doc.add(new Field(toField, "4", TextField.TYPE_STORED));
doc.add(new Field("price", "10.0", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "5", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 5
doc = new Document();
doc.add(new Field("price", "20.0", TextField.TYPE_STORED));
doc.add(new Field(idField, "6", TextField.TYPE_STORED));
doc.add(new Field(toField, "4", TextField.TYPE_STORED));
doc.add(new Field("price", "20.0", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "6", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED));
w.addDocument(doc);
IndexSearcher indexSearcher = new IndexSearcher(w.getReader());
@ -95,21 +113,21 @@ public class TestJoinUtil extends LuceneTestCase {
// Search for product
Query joinQuery =
JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name2")), indexSearcher);
JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name2")), indexSearcher, ScoreMode.None);
TopDocs result = indexSearcher.search(joinQuery, 10);
assertEquals(2, result.totalHits);
assertEquals(4, result.scoreDocs[0].doc);
assertEquals(5, result.scoreDocs[1].doc);
joinQuery = JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name1")), indexSearcher);
joinQuery = JoinUtil.createJoinQuery(idField, false, toField, new TermQuery(new Term("name", "name1")), indexSearcher, ScoreMode.None);
result = indexSearcher.search(joinQuery, 10);
assertEquals(2, result.totalHits);
assertEquals(1, result.scoreDocs[0].doc);
assertEquals(2, result.scoreDocs[1].doc);
// Search for offer
joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("id", "5")), indexSearcher);
joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("id", "5")), indexSearcher, ScoreMode.None);
result = indexSearcher.search(joinQuery, 10);
assertEquals(1, result.totalHits);
assertEquals(3, result.scoreDocs[0].doc);
@ -118,6 +136,96 @@ public class TestJoinUtil extends LuceneTestCase {
dir.close();
}
public void testSimpleWithScoring() throws Exception {
final String idField = "id";
final String toField = "movieId";
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig(TEST_VERSION_CURRENT,
new MockAnalyzer(random())).setMergePolicy(newLogMergePolicy()));
// 0
Document doc = new Document();
doc.add(new Field("description", "A random movie", TextField.TYPE_UNSTORED));
doc.add(new Field("name", "Movie 1", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "1", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 1
doc = new Document();
doc.add(new Field("subtitle", "The first subtitle of this movie", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "2", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 2
doc = new Document();
doc.add(new Field("subtitle", "random subtitle; random event movie", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "3", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "1", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 3
doc = new Document();
doc.add(new Field("description", "A second random movie", TextField.TYPE_UNSTORED));
doc.add(new Field("name", "Movie 2", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "4", TextField.TYPE_UNSTORED));
w.addDocument(doc);
w.commit();
// 4
doc = new Document();
doc.add(new Field("subtitle", "a very random event happened during christmas night", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "5", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED));
w.addDocument(doc);
// 5
doc = new Document();
doc.add(new Field("subtitle", "movie end movie test 123 test 123 random", TextField.TYPE_UNSTORED));
doc.add(new Field(idField, "6", TextField.TYPE_UNSTORED));
doc.add(new Field(toField, "4", TextField.TYPE_UNSTORED));
w.addDocument(doc);
IndexSearcher indexSearcher = new IndexSearcher(w.getReader());
w.close();
// Search for movie via subtitle
Query joinQuery =
JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "random")), indexSearcher, ScoreMode.Max);
TopDocs result = indexSearcher.search(joinQuery, 10);
assertEquals(2, result.totalHits);
assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc);
// Score mode max.
joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "movie")), indexSearcher, ScoreMode.Max);
result = indexSearcher.search(joinQuery, 10);
assertEquals(2, result.totalHits);
assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc);
// Score mode total
joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "movie")), indexSearcher, ScoreMode.Total);
result = indexSearcher.search(joinQuery, 10);
assertEquals(2, result.totalHits);
assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc);
//Score mode avg
joinQuery = JoinUtil.createJoinQuery(toField, false, idField, new TermQuery(new Term("subtitle", "movie")), indexSearcher, ScoreMode.Avg);
result = indexSearcher.search(joinQuery, 10);
assertEquals(2, result.totalHits);
assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc);
indexSearcher.getIndexReader().close();
dir.close();
}
@Test
public void testSingleValueRandomJoin() throws Exception {
int maxIndexIter = _TestUtil.nextInt(random(), 6, 12);
@ -160,15 +268,20 @@ public class TestJoinUtil extends LuceneTestCase {
String randomValue = context.randomUniqueValues[r];
FixedBitSet expectedResult = createExpectedResult(randomValue, from, indexSearcher.getIndexReader(), context);
Query actualQuery = new TermQuery(new Term("value", randomValue));
final Query actualQuery = new TermQuery(new Term("value", randomValue));
if (VERBOSE) {
System.out.println("actualQuery=" + actualQuery);
}
Query joinQuery;
final ScoreMode scoreMode = ScoreMode.values()[random().nextInt(ScoreMode.values().length)];
if (VERBOSE) {
System.out.println("scoreMode=" + scoreMode);
}
final Query joinQuery;
if (from) {
joinQuery = JoinUtil.createJoinQuery("from", multipleValuesPerDocument, "to", actualQuery, indexSearcher);
joinQuery = JoinUtil.createJoinQuery("from", multipleValuesPerDocument, "to", actualQuery, indexSearcher, scoreMode);
} else {
joinQuery = JoinUtil.createJoinQuery("to", multipleValuesPerDocument, "from", actualQuery, indexSearcher);
joinQuery = JoinUtil.createJoinQuery("to", multipleValuesPerDocument, "from", actualQuery, indexSearcher, scoreMode);
}
if (VERBOSE) {
System.out.println("joinQuery=" + joinQuery);
@ -176,26 +289,30 @@ public class TestJoinUtil extends LuceneTestCase {
// Need to know all documents that have matches. TopDocs doesn't give me that and then I'd be also testing TopDocsCollector...
final FixedBitSet actualResult = new FixedBitSet(indexSearcher.getIndexReader().maxDoc());
final TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(10, false);
indexSearcher.search(joinQuery, new Collector() {
int docBase;
public void collect(int doc) throws IOException {
actualResult.set(doc + docBase);
topScoreDocCollector.collect(doc);
}
public void setNextReader(AtomicReaderContext context) throws IOException {
docBase = context.docBase;
topScoreDocCollector.setNextReader(context);
}
public void setScorer(Scorer scorer) throws IOException {
topScoreDocCollector.setScorer(scorer);
}
public boolean acceptsDocsOutOfOrder() {
return true;
return topScoreDocCollector.acceptsDocsOutOfOrder();
}
});
// Asserting bit set...
if (VERBOSE) {
System.out.println("expected cardinality:" + expectedResult.cardinality());
DocIdSetIterator iterator = expectedResult.iterator();
@ -208,8 +325,28 @@ public class TestJoinUtil extends LuceneTestCase {
System.out.println(String.format("Actual doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id")));
}
}
assertEquals(expectedResult, actualResult);
// Asserting TopDocs...
TopDocs expectedTopDocs = createExpectedTopDocs(randomValue, from, scoreMode, context);
TopDocs actualTopDocs = topScoreDocCollector.topDocs();
assertEquals(expectedTopDocs.totalHits, actualTopDocs.totalHits);
assertEquals(expectedTopDocs.scoreDocs.length, actualTopDocs.scoreDocs.length);
if (scoreMode == ScoreMode.None) {
continue;
}
assertEquals(expectedTopDocs.getMaxScore(), actualTopDocs.getMaxScore(), 0.0f);
for (int i = 0; i < expectedTopDocs.scoreDocs.length; i++) {
if (VERBOSE) {
System.out.printf("Expected doc: %d | Actual doc: %d\n", expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc);
System.out.printf("Expected score: %f | Actual score: %f\n", expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score);
}
assertEquals(expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc);
assertEquals(expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score, 0.0f);
Explanation explanation = indexSearcher.explain(joinQuery, expectedTopDocs.scoreDocs[i].doc);
assertEquals(expectedTopDocs.scoreDocs[i].score, explanation.getValue(), 0.0f);
}
}
topLevelReader.close();
dir.close();
@ -238,20 +375,21 @@ public class TestJoinUtil extends LuceneTestCase {
context.randomUniqueValues[i] = uniqueRandomValue;
}
RandomDoc[] docs = new RandomDoc[nDocs];
for (int i = 0; i < nDocs; i++) {
String id = Integer.toString(i);
int randomI = random().nextInt(context.randomUniqueValues.length);
String value = context.randomUniqueValues[randomI];
Document document = new Document();
document.add(newField(random(), "id", id, TextField.TYPE_STORED));
document.add(newField(random(), "value", value, TextField.TYPE_STORED));
document.add(newField(random(), "id", id, TextField.TYPE_UNSTORED));
document.add(newField(random(), "value", value, TextField.TYPE_UNSTORED));
boolean from = context.randomFrom[randomI];
int numberOfLinkValues = multipleValuesPerDocument ? 2 + random().nextInt(10) : 1;
RandomDoc doc = new RandomDoc(id, numberOfLinkValues, value);
docs[i] = new RandomDoc(id, numberOfLinkValues, value, from);
for (int j = 0; j < numberOfLinkValues; j++) {
String linkValue = context.randomUniqueValues[random().nextInt(context.randomUniqueValues.length)];
doc.linkValues.add(linkValue);
docs[i].linkValues.add(linkValue);
if (from) {
if (!context.fromDocuments.containsKey(linkValue)) {
context.fromDocuments.put(linkValue, new ArrayList<RandomDoc>());
@ -260,9 +398,9 @@ public class TestJoinUtil extends LuceneTestCase {
context.randomValueFromDocs.put(value, new ArrayList<RandomDoc>());
}
context.fromDocuments.get(linkValue).add(doc);
context.randomValueFromDocs.get(value).add(doc);
document.add(newField(random(), "from", linkValue, TextField.TYPE_STORED));
context.fromDocuments.get(linkValue).add(docs[i]);
context.randomValueFromDocs.get(value).add(docs[i]);
document.add(newField(random(), "from", linkValue, TextField.TYPE_UNSTORED));
} else {
if (!context.toDocuments.containsKey(linkValue)) {
context.toDocuments.put(linkValue, new ArrayList<RandomDoc>());
@ -271,9 +409,9 @@ public class TestJoinUtil extends LuceneTestCase {
context.randomValueToDocs.put(value, new ArrayList<RandomDoc>());
}
context.toDocuments.get(linkValue).add(doc);
context.randomValueToDocs.get(value).add(doc);
document.add(newField(random(), "to", linkValue, TextField.TYPE_STORED));
context.toDocuments.get(linkValue).add(docs[i]);
context.randomValueToDocs.get(value).add(docs[i]);
document.add(newField(random(), "to", linkValue, TextField.TYPE_UNSTORED));
}
}
@ -289,12 +427,235 @@ public class TestJoinUtil extends LuceneTestCase {
w.commit();
}
if (VERBOSE) {
System.out.println("Added document[" + i + "]: " + document);
System.out.println("Added document[" + docs[i].id + "]: " + document);
}
}
// Pre-compute all possible hits for all unique random values. On top of this also compute all possible score for
// any ScoreMode.
IndexSearcher fromSearcher = newSearcher(fromWriter.getReader());
IndexSearcher toSearcher = newSearcher(toWriter.getReader());
for (int i = 0; i < context.randomUniqueValues.length; i++) {
String uniqueRandomValue = context.randomUniqueValues[i];
final String fromField;
final String toField;
final Map<String, Map<Integer, JoinScore>> queryVals;
if (context.randomFrom[i]) {
fromField = "from";
toField = "to";
queryVals = context.fromHitsToJoinScore;
} else {
fromField = "to";
toField = "from";
queryVals = context.toHitsToJoinScore;
}
final Map<BytesRef, JoinScore> joinValueToJoinScores = new HashMap<BytesRef, JoinScore>();
if (multipleValuesPerDocument) {
fromSearcher.search(new TermQuery(new Term("value", uniqueRandomValue)), new Collector() {
private Scorer scorer;
private DocTermOrds docTermOrds;
private TermsEnum docTermsEnum;
private DocTermOrds.TermOrdsIterator reuse;
public void collect(int doc) throws IOException {
if (docTermOrds.isEmpty()) {
return;
}
reuse = docTermOrds.lookup(doc, reuse);
int[] buffer = new int[5];
int chunk;
do {
chunk = reuse.read(buffer);
if (chunk == 0) {
return;
}
for (int idx = 0; idx < chunk; idx++) {
int key = buffer[idx];
docTermsEnum.seekExact((long) key);
BytesRef joinValue = docTermsEnum.term();
if (joinValue == null) {
continue;
}
JoinScore joinScore = joinValueToJoinScores.get(joinValue);
if (joinScore == null) {
joinValueToJoinScores.put(BytesRef.deepCopyOf(joinValue), joinScore = new JoinScore());
}
joinScore.addScore(scorer.score());
}
} while (chunk >= buffer.length);
}
public void setNextReader(AtomicReaderContext context) throws IOException {
docTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), fromField);
docTermsEnum = docTermOrds.getOrdTermsEnum(context.reader());
reuse = null;
}
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
}
public boolean acceptsDocsOutOfOrder() {
return false;
}
});
} else {
fromSearcher.search(new TermQuery(new Term("value", uniqueRandomValue)), new Collector() {
private Scorer scorer;
private FieldCache.DocTerms terms;
private final BytesRef spare = new BytesRef();
public void collect(int doc) throws IOException {
BytesRef joinValue = terms.getTerm(doc, spare);
if (joinValue == null) {
return;
}
JoinScore joinScore = joinValueToJoinScores.get(joinValue);
if (joinScore == null) {
joinValueToJoinScores.put(BytesRef.deepCopyOf(joinValue), joinScore = new JoinScore());
}
joinScore.addScore(scorer.score());
}
public void setNextReader(AtomicReaderContext context) throws IOException {
terms = FieldCache.DEFAULT.getTerms(context.reader(), fromField);
}
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
}
public boolean acceptsDocsOutOfOrder() {
return false;
}
});
}
final Map<Integer, JoinScore> docToJoinScore = new HashMap<Integer, JoinScore>();
if (multipleValuesPerDocument) {
toSearcher.search(new MatchAllDocsQuery(), new Collector() {
private DocTermOrds docTermOrds;
private TermsEnum docTermsEnum;
private DocTermOrds.TermOrdsIterator reuse;
private int docBase;
public void collect(int doc) throws IOException {
if (docTermOrds.isEmpty()) {
return;
}
reuse = docTermOrds.lookup(doc, reuse);
int[] buffer = new int[5];
int chunk;
do {
chunk = reuse.read(buffer);
if (chunk == 0) {
return;
}
for (int idx = 0; idx < chunk; idx++) {
int key = buffer[idx];
docTermsEnum.seekExact((long) key);
JoinScore joinScore = joinValueToJoinScores.get(docTermsEnum.term());
if (joinScore == null) {
continue;
}
Integer basedDoc = docBase + doc;
// First encountered join value determines the score.
// Something to keep in mind for many-to-many relations.
if (!docToJoinScore.containsKey(basedDoc)) {
docToJoinScore.put(basedDoc, joinScore);
}
}
} while (chunk >= buffer.length);
}
public void setNextReader(AtomicReaderContext context) throws IOException {
docBase = context.docBase;
docTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), toField);
docTermsEnum = docTermOrds.getOrdTermsEnum(context.reader());
reuse = null;
}
public boolean acceptsDocsOutOfOrder() {return false;}
public void setScorer(Scorer scorer) throws IOException {}
});
} else {
toSearcher.search(new MatchAllDocsQuery(), new Collector() {
private FieldCache.DocTerms terms;
private int docBase;
private final BytesRef spare = new BytesRef();
public void collect(int doc) throws IOException {
JoinScore joinScore = joinValueToJoinScores.get(terms.getTerm(doc, spare));
if (joinScore == null) {
return;
}
docToJoinScore.put(docBase + doc, joinScore);
}
public void setNextReader(AtomicReaderContext context) throws IOException {
terms = FieldCache.DEFAULT.getTerms(context.reader(), toField);
docBase = context.docBase;
}
public boolean acceptsDocsOutOfOrder() {return false;}
public void setScorer(Scorer scorer) throws IOException {}
});
}
queryVals.put(uniqueRandomValue, docToJoinScore);
}
fromSearcher.getIndexReader().close();
toSearcher.getIndexReader().close();
return context;
}
private TopDocs createExpectedTopDocs(String queryValue,
final boolean from,
final ScoreMode scoreMode,
IndexIterationContext context) throws IOException {
Map<Integer, JoinScore> hitsToJoinScores;
if (from) {
hitsToJoinScores = context.fromHitsToJoinScore.get(queryValue);
} else {
hitsToJoinScores = context.toHitsToJoinScore.get(queryValue);
}
List<Map.Entry<Integer,JoinScore>> hits = new ArrayList<Map.Entry<Integer, JoinScore>>(hitsToJoinScores.entrySet());
Collections.sort(hits, new Comparator<Map.Entry<Integer, JoinScore>>() {
public int compare(Map.Entry<Integer, JoinScore> hit1, Map.Entry<Integer, JoinScore> hit2) {
float score1 = hit1.getValue().score(scoreMode);
float score2 = hit2.getValue().score(scoreMode);
int cmp = Float.compare(score2, score1);
if (cmp != 0) {
return cmp;
}
return hit1.getKey() - hit2.getKey();
}
});
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(10, hits.size())];
for (int i = 0; i < scoreDocs.length; i++) {
Map.Entry<Integer,JoinScore> hit = hits.get(i);
scoreDocs[i] = new ScoreDoc(hit.getKey(), hit.getValue().score(scoreMode));
}
return new TopDocs(hits.size(), scoreDocs, hits.isEmpty() ? Float.NaN : hits.get(0).getValue().score(scoreMode));
}
private FixedBitSet createExpectedResult(String queryValue, boolean from, IndexReader topLevelReader, IndexIterationContext context) throws IOException {
final Map<String, List<RandomDoc>> randomValueDocs;
final Map<String, List<RandomDoc>> linkValueDocuments;
@ -339,6 +700,9 @@ public class TestJoinUtil extends LuceneTestCase {
Map<String, List<RandomDoc>> randomValueFromDocs = new HashMap<String, List<RandomDoc>>();
Map<String, List<RandomDoc>> randomValueToDocs = new HashMap<String, List<RandomDoc>>();
Map<String, Map<Integer, JoinScore>> fromHitsToJoinScore = new HashMap<String, Map<Integer, JoinScore>>();
Map<String, Map<Integer, JoinScore>> toHitsToJoinScore = new HashMap<String, Map<Integer, JoinScore>>();
}
private static class RandomDoc {
@ -346,12 +710,44 @@ public class TestJoinUtil extends LuceneTestCase {
final String id;
final List<String> linkValues;
final String value;
final boolean from;
private RandomDoc(String id, int numberOfLinkValues, String value) {
private RandomDoc(String id, int numberOfLinkValues, String value, boolean from) {
this.id = id;
this.from = from;
linkValues = new ArrayList<String>(numberOfLinkValues);
this.value = value;
}
}
private static class JoinScore {
float maxScore;
float total;
int count;
void addScore(float score) {
total += score;
if (score > maxScore) {
maxScore = score;
}
count++;
}
float score(ScoreMode mode) {
switch (mode) {
case None:
return 1.0f;
case Total:
return total;
case Avg:
return total / count;
case Max:
return maxScore;
}
throw new IllegalArgumentException("Unsupported ScoreMode: " + mode);
}
}
}