mirror of https://github.com/apache/lucene.git
LUCENE-6389: Added ScoreMode.Min that aggregates the lowest child score to the parent hit.
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1674622 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent e699256bbb
commit 07beb24d93
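For orientation, here is a hedged usage sketch of the new mode through JoinUtil's query-time join; the field names, queries, and searcher wiring below are illustrative assumptions, not part of this commit:

import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.JoinUtil;
import org.apache.lucene.search.join.ScoreMode;

import java.io.IOException;

class MinJoinExample {
  // Joins "from" documents onto "to" documents. With ScoreMode.Min each joined
  // "to" hit is scored with the lowest score among its matching "from" documents.
  static TopDocs searchWithMinJoin(IndexSearcher searcher) throws IOException {
    Query fromQuery = new TermQuery(new Term("type", "from"));  // hypothetical field/value
    Query joinQuery = JoinUtil.createJoinQuery(
        "join_field",    // fromField (hypothetical)
        false,           // multipleValuesPerDocument
        "join_field",    // toField (hypothetical)
        fromQuery,
        searcher,
        ScoreMode.Min);  // the mode added by this commit
    return searcher.search(joinQuery, 10);
  }
}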
lucene/CHANGES.txt
@@ -67,6 +67,9 @@ New Features
   StoredFieldVisitor.stringField API to take UTF-8 byte[] instead of
   String (Mike McCandless)
 
+* LUCENE-6389: Added ScoreMode.Min that aggregates the lowest child score
+  to the parent hit. (Martijn van Groningen, Adrien Grand)
+
 Optimizations
 
 * LUCENE-6379: IndexWriter.deleteDocuments(Query...) now detects if
GlobalOrdinalsWithScoreCollector.java
@@ -28,6 +28,7 @@ import org.apache.lucene.util.LongBitSet;
 import org.apache.lucene.util.LongValues;
 
 import java.io.IOException;
+import java.util.Arrays;
 
 abstract class GlobalOrdinalsWithScoreCollector implements Collector {
 
@@ -44,7 +45,7 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
     this.field = field;
     this.ordinalMap = ordinalMap;
     this.collectedOrds = new LongBitSet(valueCount);
-    this.scores = new Scores(valueCount);
+    this.scores = new Scores(valueCount, unset());
   }
 
   public LongBitSet getCollectorOrdinals() {
@@ -57,6 +58,8 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
 
   protected abstract void doScore(int globalOrd, float existingScore, float newScore);
 
+  protected abstract float unset();
+
   @Override
   public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
     SortedDocValues docTermOrds = DocValues.getSorted(context.reader(), field);
@@ -128,6 +131,23 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
     }
   }
 
+  static final class Min extends GlobalOrdinalsWithScoreCollector {
+
+    public Min(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
+      super(field, ordinalMap, valueCount);
+    }
+
+    @Override
+    protected void doScore(int globalOrd, float existingScore, float newScore) {
+      scores.setScore(globalOrd, Math.min(existingScore, newScore));
+    }
+
+    @Override
+    protected float unset() {
+      return Float.POSITIVE_INFINITY;
+    }
+  }
+
   static final class Max extends GlobalOrdinalsWithScoreCollector {
 
     public Max(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
@@ -139,6 +159,10 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
       scores.setScore(globalOrd, Math.max(existingScore, newScore));
     }
 
+    @Override
+    protected float unset() {
+      return Float.NEGATIVE_INFINITY;
+    }
   }
 
   static final class Sum extends GlobalOrdinalsWithScoreCollector {
@@ -152,6 +176,10 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
       scores.setScore(globalOrd, existingScore + newScore);
     }
 
+    @Override
+    protected float unset() {
+      return 0f;
+    }
   }
 
   static final class Avg extends GlobalOrdinalsWithScoreCollector {
@@ -173,6 +201,11 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
     public float score(int globalOrdinal) {
       return scores.getScore(globalOrdinal) / occurrences.getOccurrence(globalOrdinal);
     }
+
+    @Override
+    protected float unset() {
+      return 0f;
+    }
   }
 
   // Because the global ordinal is directly used as a key to a score we should be somewhat smart about allocation
@@ -188,10 +221,12 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
   static final class Scores {
 
     final float[][] blocks;
+    final float unset;
 
-    private Scores(long valueCount) {
+    private Scores(long valueCount, float unset) {
       long blockSize = valueCount + arraySize - 1;
       blocks = new float[(int) ((blockSize) / arraySize)][];
+      this.unset = unset;
     }
 
     public void setScore(int globalOrdinal, float score) {
@@ -200,6 +235,9 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
       float[] scores = blocks[block];
       if (scores == null) {
         blocks[block] = scores = new float[arraySize];
+        if (unset != 0f) {
+          Arrays.fill(scores, unset);
+        }
       }
       scores[offset] = score;
     }
@@ -212,7 +250,7 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
       if (scores != null) {
        score = scores[offset];
      } else {
-        score = 0f;
+        score = unset;
      }
      return score;
    }
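The collector changes above hinge on one detail: the per-ordinal score blocks are allocated lazily, so a slot that has never been scored must read as the identity value of the aggregation (0 for Sum/Avg, positive infinity for Min, negative infinity for Max), which is what the new unset() hook supplies. Below is a minimal standalone sketch of that pattern, not the Lucene class itself; block size and class name are invented for the example:

import java.util.Arrays;

// Sketch of the "unset sentinel" idea: lazily allocated blocks are pre-filled with
// the aggregation's identity value so Min/Max behave correctly on never-scored slots.
final class LazyScores {
  private static final int BLOCK = 4096;   // arbitrary block size for the sketch
  private final float[][] blocks;
  private final float unset;               // 0f for sum/avg, +Inf for min, -Inf for max

  LazyScores(long valueCount, float unset) {
    this.blocks = new float[(int) ((valueCount + BLOCK - 1) / BLOCK)][];
    this.unset = unset;
  }

  float get(int ord) {
    float[] block = blocks[ord / BLOCK];
    return block == null ? unset : block[ord % BLOCK];
  }

  void aggregateMin(int ord, float score) {
    float[] block = blocks[ord / BLOCK];
    if (block == null) {
      block = blocks[ord / BLOCK] = new float[BLOCK];
      Arrays.fill(block, unset);           // so Math.min sees +Inf, not 0f
    }
    block[ord % BLOCK] = Math.min(block[ord % BLOCK], score);
  }
}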
JoinUtil.java
@@ -79,6 +79,7 @@ public final class JoinUtil {
         return new TermsQuery(toField, fromQuery, termsCollector.getCollectorTerms());
       case Total:
       case Max:
+      case Min:
       case Avg:
         TermsWithScoreCollector termsWithScoreCollector =
             TermsWithScoreCollector.create(fromField, multipleValuesPerDocument, scoreMode);
@@ -144,10 +145,10 @@ public final class JoinUtil {
       valueCount = ordinalMap.getValueCount();
     }
 
-    Query rewrittenFromQuery = searcher.rewrite(fromQuery);
+    final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
     if (scoreMode == ScoreMode.None) {
       GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount);
-      searcher.search(fromQuery, globalOrdinalsCollector);
+      searcher.search(rewrittenFromQuery, globalOrdinalsCollector);
       return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader);
     }
 
@@ -156,6 +157,9 @@ public final class JoinUtil {
       case Total:
         globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount);
         break;
+      case Min:
+        globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount);
+        break;
       case Max:
         globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount);
         break;
@@ -165,7 +169,7 @@ public final class JoinUtil {
       default:
         throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
     }
-    searcher.search(fromQuery, globalOrdinalsWithScoreCollector);
+    searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector);
     return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader);
   }
 
ScoreMode.java
@@ -40,6 +40,11 @@ public enum ScoreMode {
   /**
    * Parent hit's score is the sum of all child scores.
    */
-  Total
+  Total,
+
+  /**
+   * Parent hit's score is the min of all child scores.
+   */
+  Min
 
 }
TermsWithScoreCollector.java
@@ -17,20 +17,21 @@ package org.apache.lucene.search.join;
  * limitations under the License.
  */
 
-import java.io.IOException;
-
-import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.index.DocValues;
+import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedSetDocValues;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.SimpleCollector;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRefHash;
 
+import java.io.IOException;
+import java.util.Arrays;
+
 abstract class TermsWithScoreCollector extends SimpleCollector {
 
-  private final static int INITIAL_ARRAY_SIZE = 256;
+  private final static int INITIAL_ARRAY_SIZE = 0;
 
   final String field;
   final BytesRefHash collectedTerms = new BytesRefHash();
@@ -42,6 +43,11 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
   TermsWithScoreCollector(String field, ScoreMode scoreMode) {
     this.field = field;
     this.scoreMode = scoreMode;
+    if (scoreMode == ScoreMode.Min) {
+      Arrays.fill(scoreSums, Float.POSITIVE_INFINITY);
+    } else if (scoreMode == ScoreMode.Max) {
+      Arrays.fill(scoreSums, Float.NEGATIVE_INFINITY);
+    }
   }
 
   public BytesRefHash getCollectedTerms() {
@@ -98,7 +104,13 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
         ord = -ord - 1;
       } else {
         if (ord >= scoreSums.length) {
+          int begin = scoreSums.length;
           scoreSums = ArrayUtil.grow(scoreSums);
+          if (scoreMode == ScoreMode.Min) {
+            Arrays.fill(scoreSums, begin, scoreSums.length, Float.POSITIVE_INFINITY);
+          } else if (scoreMode == ScoreMode.Max) {
+            Arrays.fill(scoreSums, begin, scoreSums.length, Float.NEGATIVE_INFINITY);
+          }
         }
       }
 
@@ -111,10 +123,16 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
           case Total:
             scoreSums[ord] = scoreSums[ord] + current;
            break;
+          case Min:
+            if (current < existing) {
+              scoreSums[ord] = current;
+            }
+            break;
           case Max:
            if (current > existing) {
              scoreSums[ord] = current;
            }
+            break;
         }
       }
     }
@@ -187,7 +205,13 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
         termID = -termID - 1;
       } else {
         if (termID >= scoreSums.length) {
+          int begin = scoreSums.length;
           scoreSums = ArrayUtil.grow(scoreSums);
+          if (scoreMode == ScoreMode.Min) {
+            Arrays.fill(scoreSums, begin, scoreSums.length, Float.POSITIVE_INFINITY);
+          } else if (scoreMode == ScoreMode.Max) {
+            Arrays.fill(scoreSums, begin, scoreSums.length, Float.NEGATIVE_INFINITY);
+          }
         }
       }
 
@@ -195,8 +219,12 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
         case Total:
           scoreSums[termID] += scorer.score();
           break;
+        case Min:
+          scoreSums[termID] = Math.min(scoreSums[termID], scorer.score());
+          break;
         case Max:
          scoreSums[termID] = Math.max(scoreSums[termID], scorer.score());
+          break;
       }
     }
   }
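Both the single-value and multi-value collectors above repeat the same fix: whenever scoreSums grows, the newly exposed tail has to be refilled with the Min/Max sentinel, otherwise those slots would default to 0f and win every Math.min comparison. A standalone sketch of that grow-and-refill pattern follows, using Arrays.copyOf in place of Lucene's ArrayUtil.grow so the snippet has no Lucene dependency; class and method names are invented for the example:

import java.util.Arrays;

// Sketch of the grow-and-refill pattern: new array slots start at the sentinel
// (the aggregation's identity value) instead of Java's default 0f.
final class GrowableScores {
  private float[] scores = new float[0];
  private final float sentinel;   // +Inf for min, -Inf for max, 0f otherwise

  GrowableScores(float sentinel) {
    this.sentinel = sentinel;
  }

  void ensureCapacity(int ord) {
    if (ord >= scores.length) {
      int begin = scores.length;
      scores = Arrays.copyOf(scores, Math.max(ord + 1, begin * 2));
      Arrays.fill(scores, begin, scores.length, sentinel);  // keep new slots at the identity value
    }
  }

  void collectMin(int ord, float score) {
    ensureCapacity(ord);
    scores[ord] = Math.min(scores[ord], score);
  }
}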
ToParentBlockJoinQuery.java
@@ -17,12 +17,6 @@ package org.apache.lucene.search.join;
  * limitations under the License.
  */
 
-import java.io.IOException;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Locale;
-import java.util.Set;
-
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.LeafReaderContext;
@@ -40,6 +34,12 @@ import org.apache.lucene.util.BitDocIdSet;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Locale;
+import java.util.Set;
+
 /**
  * This query requires that you index
  * children and parent docs as a single block, using the
@@ -297,6 +297,7 @@ public class ToParentBlockJoinQuery extends Query {
 
       float totalScore = 0;
       float maxScore = Float.NEGATIVE_INFINITY;
+      float minScore = Float.POSITIVE_INFINITY;
 
       childDocUpto = 0;
       parentFreq = 0;
@@ -320,6 +321,7 @@ public class ToParentBlockJoinQuery extends Query {
           pendingChildScores[childDocUpto] = childScore;
         }
         maxScore = Math.max(childScore, maxScore);
+        minScore = Math.min(childFreq, minScore);
         totalScore += childScore;
         parentFreq += childFreq;
       }
@@ -340,6 +342,9 @@ public class ToParentBlockJoinQuery extends Query {
       case Max:
         parentScore = maxScore;
         break;
+      case Min:
+        parentScore = minScore;
+        break;
       case Total:
        parentScore = totalScore;
        break;
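The index-time block join gains the same mode. A hedged usage sketch of ToParentBlockJoinQuery with it follows, assuming the 5.x-era BitDocIdSetCachingWrapperFilter/QueryWrapperFilter parent filter and hypothetical field names; it is an illustration, not code from this commit:

import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryWrapperFilter;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitDocIdSetCachingWrapperFilter;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.search.join.ToParentBlockJoinQuery;

import java.io.IOException;

class MinBlockJoinExample {
  // Assumes parents and children were indexed together with addDocuments(),
  // children first and the parent document last; field names are hypothetical.
  static TopDocs searchCheapestChildFirst(IndexSearcher searcher) throws IOException {
    BitDocIdSetCachingWrapperFilter parentsFilter = new BitDocIdSetCachingWrapperFilter(
        new QueryWrapperFilter(new TermQuery(new Term("docType", "parent"))));
    Query childQuery = new TermQuery(new Term("color", "red"));
    // ScoreMode.Min: each parent hit is scored with its lowest-scoring matching child.
    ToParentBlockJoinQuery joinQuery =
        new ToParentBlockJoinQuery(childQuery, parentsFilter, ScoreMode.Min);
    return searcher.search(joinQuery, 10);
  }
}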
TestBlockJoin.java
@@ -17,13 +17,6 @@ package org.apache.lucene.search.join;
  * limitations under the License.
  */
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -32,7 +25,6 @@ import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.SortedDocValuesField;
 import org.apache.lucene.document.StoredField;
 import org.apache.lucene.index.DirectoryReader;
-import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
@@ -40,31 +32,13 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.LogDocMergePolicy;
 import org.apache.lucene.index.MultiFields;
 import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.ReaderUtil;
 import org.apache.lucene.index.StoredDocument;
 import org.apache.lucene.index.Term;
-import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.*;
 import org.apache.lucene.search.BooleanClause.Occur;
-import org.apache.lucene.search.BooleanQuery;
-import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.FieldDoc;
-import org.apache.lucene.search.Filter;
-import org.apache.lucene.search.FilteredQuery;
-import org.apache.lucene.search.IndexSearcher;
-import org.apache.lucene.search.MatchAllDocsQuery;
-import org.apache.lucene.search.MultiTermQuery;
-import org.apache.lucene.search.NumericRangeQuery;
-import org.apache.lucene.search.Query;
-import org.apache.lucene.search.QueryUtils;
-import org.apache.lucene.search.QueryWrapperFilter;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.Sort;
-import org.apache.lucene.search.SortField;
-import org.apache.lucene.search.TermQuery;
-import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.Weight;
 import org.apache.lucene.search.grouping.GroupDocs;
 import org.apache.lucene.search.grouping.TopGroups;
 import org.apache.lucene.store.Directory;
@@ -76,6 +50,13 @@ import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.NumericUtils;
 import org.apache.lucene.util.TestUtil;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+
 public class TestBlockJoin extends LuceneTestCase {
 
   // One resume...
@@ -401,7 +382,7 @@ public class TestBlockJoin extends LuceneTestCase {
     final List<LeafReaderContext> leaves = reader.leaves();
     final int subIndex = ReaderUtil.subIndex(childDocID, leaves);
     final LeafReaderContext leaf = leaves.get(subIndex);
-    final BitSet bits = (BitSet) parents.getDocIdSet(leaf).bits();
+    final BitSet bits = parents.getDocIdSet(leaf).bits();
     return leaf.reader().document(bits.nextSetBit(childDocID - leaf.docBase));
   }
 
@@ -722,18 +703,8 @@ public class TestBlockJoin extends LuceneTestCase {
           random().nextBoolean() ? BooleanClause.Occur.MUST : BooleanClause.Occur.MUST_NOT);
       }
 
-      final int x = random().nextInt(4);
-      final ScoreMode agg;
-      if (x == 0) {
-        agg = ScoreMode.None;
-      } else if (x == 1) {
-        agg = ScoreMode.Max;
-      } else if (x == 2) {
-        agg = ScoreMode.Total;
-      } else {
-        agg = ScoreMode.Avg;
-      }
-
+      final ScoreMode agg = ScoreMode.values()[random().nextInt(ScoreMode.values().length)];
       final ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, agg);
 
      // To run against the block-join index:
TestJoinUtil.java
@@ -22,8 +22,10 @@ import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.analysis.MockTokenizer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.SortedDocValuesField;
 import org.apache.lucene.document.SortedSetDocValuesField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.document.TextField;
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.index.DirectoryReader;
@@ -34,6 +36,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.MultiDocValues;
 import org.apache.lucene.index.MultiFields;
 import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.index.NumericDocValues;
 import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.SlowCompositeReaderWrapper;
@@ -46,6 +49,8 @@ import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.FieldValueQuery;
+import org.apache.lucene.search.FilterScorer;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.MultiCollector;
@@ -56,6 +61,7 @@ import org.apache.lucene.search.SimpleCollector;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopScoreDocCollector;
+import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
@@ -332,6 +338,127 @@ public class TestJoinUtil extends LuceneTestCase {
     context.close();
   }
 
+  public void testMinMaxScore() throws Exception {
+    String priceField = "price";
+    // FunctionQuery would be helpful, but join module doesn't depend on queries module.
+    Query priceQuery = new Query() {
+
+      private final Query fieldQuery = new FieldValueQuery(priceField);
+
+      @Override
+      public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
+        Weight fieldWeight = fieldQuery.createWeight(searcher, false);
+        return new Weight(this) {
+
+          @Override
+          public void extractTerms(Set<Term> terms) {
+          }
+
+          @Override
+          public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+            return null;
+          }
+
+          @Override
+          public float getValueForNormalization() throws IOException {
+            return 0;
+          }
+
+          @Override
+          public void normalize(float norm, float topLevelBoost) {
+          }
+
+          @Override
+          public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException {
+            Scorer fieldScorer = fieldWeight.scorer(context, acceptDocs);
+            NumericDocValues price = context.reader().getNumericDocValues(priceField);
+            return new FilterScorer(fieldScorer, this) {
+              @Override
+              public float score() throws IOException {
+                return (float) price.get(in.docID());
+              }
+            };
+          }
+        };
+      }
+
+      @Override
+      public String toString(String field) {
+        return fieldQuery.toString(field);
+      }
+    };
+
+    Directory dir = newDirectory();
+    RandomIndexWriter iw = new RandomIndexWriter(
+        random(),
+        dir,
+        newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false))
+    );
+    Map<String, Float> lowestScoresPerParent = new HashMap<>();
+    Map<String, Float> highestScoresPerParent = new HashMap<>();
+    int numParents = RandomInts.randomIntBetween(random(), 16, 64);
+    for (int p = 0; p < numParents; p++) {
+      String parentId = Integer.toString(p);
+      Document parentDoc = new Document();
+      parentDoc.add(new StringField("id", parentId, Field.Store.YES));
+      parentDoc.add(new StringField("type", "to", Field.Store.NO));
+      parentDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId)));
+      iw.addDocument(parentDoc);
+      int numChildren = RandomInts.randomIntBetween(random(), 2, 16);
+      int lowest = Integer.MAX_VALUE;
+      int highest = Integer.MIN_VALUE;
+      for (int c = 0; c < numChildren; c++) {
+        String childId = Integer.toString(p + c);
+        Document childDoc = new Document();
+        childDoc.add(new StringField("id", childId, Field.Store.YES));
+        parentDoc.add(new StringField("type", "from", Field.Store.NO));
+        childDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId)));
+        int price = random().nextInt(1000);
+        childDoc.add(new NumericDocValuesField(priceField, price));
+        iw.addDocument(childDoc);
+        lowest = Math.min(lowest, price);
+        highest = Math.max(highest, price);
+      }
+      lowestScoresPerParent.put(parentId, (float) lowest);
+      highestScoresPerParent.put(parentId, (float) highest);
+    }
+    iw.close();
+
+    IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
+    SortedDocValues[] values = new SortedDocValues[searcher.getIndexReader().leaves().size()];
+    for (LeafReaderContext leadContext : searcher.getIndexReader().leaves()) {
+      values[leadContext.ord] = DocValues.getSorted(leadContext.reader(), "join_field");
+    }
+    MultiDocValues.OrdinalMap ordinalMap = MultiDocValues.OrdinalMap.build(
+        searcher.getIndexReader().getCoreCacheKey(), values, PackedInts.DEFAULT
+    );
+    BooleanQuery fromQuery = new BooleanQuery();
+    fromQuery.add(priceQuery, BooleanClause.Occur.MUST);
+    Query toQuery = new TermQuery(new Term("type", "to"));
+    Query joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, searcher, ScoreMode.Min, ordinalMap);
+    TopDocs topDocs = searcher.search(joinQuery, numParents);
+    assertEquals(numParents, topDocs.totalHits);
+    for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+      ScoreDoc scoreDoc = topDocs.scoreDocs[i];
+      String id = SlowCompositeReaderWrapper.wrap(searcher.getIndexReader()).document(scoreDoc.doc).get("id");
+      assertEquals(lowestScoresPerParent.get(id), scoreDoc.score, 0f);
+    }
+
+    joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, searcher, ScoreMode.Max, ordinalMap);
+    topDocs = searcher.search(joinQuery, numParents);
+    assertEquals(numParents, topDocs.totalHits);
+    for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+      ScoreDoc scoreDoc = topDocs.scoreDocs[i];
+      String id = SlowCompositeReaderWrapper.wrap(searcher.getIndexReader()).document(scoreDoc.doc).get("id");
+      assertEquals(highestScoresPerParent.get(id), scoreDoc.score, 0f);
+    }
+
+    searcher.getIndexReader().close();
+    dir.close();
+  }
+
   // TermsWithScoreCollector.MV.Avg forgets to grow beyond TermsWithScoreCollector.INITIAL_ARRAY_SIZE
   public void testOverflowTermsWithScoreCollector() throws Exception {
     test300spartans(true, ScoreMode.Avg);
@@ -382,7 +509,6 @@ public class TestJoinUtil extends LuceneTestCase {
     assertEquals(1, result.totalHits);
     assertEquals(0, result.scoreDocs[0].doc);
 
-
     indexSearcher.getIndexReader().close();
     dir.close();
   }
@@ -1073,26 +1199,32 @@ public class TestJoinUtil extends LuceneTestCase {
 
   private static class JoinScore {
 
-    float maxScore;
+    float minScore = Float.POSITIVE_INFINITY;
+    float maxScore = Float.NEGATIVE_INFINITY;
     float total;
     int count;
 
     void addScore(float score) {
-      total += score;
       if (score > maxScore) {
        maxScore = score;
      }
+      if (score < minScore) {
+        minScore = score;
+      }
+      total += score;
       count++;
     }
 
     float score(ScoreMode mode) {
       switch (mode) {
         case None:
-          return 1.0f;
+          return 1f;
         case Total:
          return total;
        case Avg:
          return total / count;
+        case Min:
+          return minScore;
        case Max:
          return maxScore;
      }