LUCENE-6389: Added ScoreMode.Min that aggregates the lowest child score to the parent hit.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1674622 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Martijn van Groningen 2015-04-19 13:14:13 +00:00
parent e699256bbb
commit 07beb24d93
8 changed files with 247 additions and 61 deletions

View File

@ -67,6 +67,9 @@ New Features
StoredFieldVisitor.stringField API to take UTF-8 byte[] instead of
String (Mike McCandless)
* LUCENE-6389: Added ScoreMode.Min that aggregates the lowest child score
to the parent hit. (Martijn van Groningen, Adrien Grand)
Optimizations
* LUCENE-6379: IndexWriter.deleteDocuments(Query...) now detects if

View File

@ -28,6 +28,7 @@ import org.apache.lucene.util.LongBitSet;
import org.apache.lucene.util.LongValues;
import java.io.IOException;
import java.util.Arrays;
abstract class GlobalOrdinalsWithScoreCollector implements Collector {
@ -44,7 +45,7 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
this.field = field;
this.ordinalMap = ordinalMap;
this.collectedOrds = new LongBitSet(valueCount);
this.scores = new Scores(valueCount);
this.scores = new Scores(valueCount, unset());
}
public LongBitSet getCollectorOrdinals() {
@ -57,6 +58,8 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
protected abstract void doScore(int globalOrd, float existingScore, float newScore);
protected abstract float unset();
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
SortedDocValues docTermOrds = DocValues.getSorted(context.reader(), field);
@ -128,6 +131,23 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
}
}
static final class Min extends GlobalOrdinalsWithScoreCollector {
public Min(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
super(field, ordinalMap, valueCount);
}
@Override
protected void doScore(int globalOrd, float existingScore, float newScore) {
scores.setScore(globalOrd, Math.min(existingScore, newScore));
}
@Override
protected float unset() {
return Float.POSITIVE_INFINITY;
}
}
static final class Max extends GlobalOrdinalsWithScoreCollector {
public Max(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
@ -139,6 +159,10 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
scores.setScore(globalOrd, Math.max(existingScore, newScore));
}
@Override
protected float unset() {
return Float.NEGATIVE_INFINITY;
}
}
static final class Sum extends GlobalOrdinalsWithScoreCollector {
@ -152,6 +176,10 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
scores.setScore(globalOrd, existingScore + newScore);
}
@Override
protected float unset() {
return 0f;
}
}
static final class Avg extends GlobalOrdinalsWithScoreCollector {
@ -173,6 +201,11 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
public float score(int globalOrdinal) {
return scores.getScore(globalOrdinal) / occurrences.getOccurrence(globalOrdinal);
}
@Override
protected float unset() {
return 0f;
}
}
// Because the global ordinal is directly used as a key to a score we should be somewhat smart about allocation
@ -188,10 +221,12 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
static final class Scores {
final float[][] blocks;
final float unset;
private Scores(long valueCount) {
private Scores(long valueCount, float unset) {
long blockSize = valueCount + arraySize - 1;
blocks = new float[(int) ((blockSize) / arraySize)][];
this.unset = unset;
}
public void setScore(int globalOrdinal, float score) {
@ -200,6 +235,9 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
float[] scores = blocks[block];
if (scores == null) {
blocks[block] = scores = new float[arraySize];
if (unset != 0f) {
Arrays.fill(scores, unset);
}
}
scores[offset] = score;
}
@ -212,7 +250,7 @@ abstract class GlobalOrdinalsWithScoreCollector implements Collector {
if (scores != null) {
score = scores[offset];
} else {
score = 0f;
score = unset;
}
return score;
}

View File

@ -79,6 +79,7 @@ public final class JoinUtil {
return new TermsQuery(toField, fromQuery, termsCollector.getCollectorTerms());
case Total:
case Max:
case Min:
case Avg:
TermsWithScoreCollector termsWithScoreCollector =
TermsWithScoreCollector.create(fromField, multipleValuesPerDocument, scoreMode);
@ -144,10 +145,10 @@ public final class JoinUtil {
valueCount = ordinalMap.getValueCount();
}
Query rewrittenFromQuery = searcher.rewrite(fromQuery);
final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
if (scoreMode == ScoreMode.None) {
GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount);
searcher.search(fromQuery, globalOrdinalsCollector);
searcher.search(rewrittenFromQuery, globalOrdinalsCollector);
return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader);
}
@ -156,6 +157,9 @@ public final class JoinUtil {
case Total:
globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount);
break;
case Min:
globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount);
break;
case Max:
globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount);
break;
@ -165,7 +169,7 @@ public final class JoinUtil {
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
}
searcher.search(fromQuery, globalOrdinalsWithScoreCollector);
searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector);
return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader);
}

View File

@ -40,6 +40,11 @@ public enum ScoreMode {
/**
* Parent hit's score is the sum of all child scores.
*/
Total
Total,
/**
* Parent hit's score is the min of all child scores.
*/
Min
}

View File

@ -17,20 +17,21 @@ package org.apache.lucene.search.join;
* limitations under the License.
*/
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRefHash;
import java.io.IOException;
import java.util.Arrays;
abstract class TermsWithScoreCollector extends SimpleCollector {
private final static int INITIAL_ARRAY_SIZE = 256;
private final static int INITIAL_ARRAY_SIZE = 0;
final String field;
final BytesRefHash collectedTerms = new BytesRefHash();
@ -42,6 +43,11 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
TermsWithScoreCollector(String field, ScoreMode scoreMode) {
this.field = field;
this.scoreMode = scoreMode;
if (scoreMode == ScoreMode.Min) {
Arrays.fill(scoreSums, Float.POSITIVE_INFINITY);
} else if (scoreMode == ScoreMode.Max) {
Arrays.fill(scoreSums, Float.NEGATIVE_INFINITY);
}
}
public BytesRefHash getCollectedTerms() {
@ -98,7 +104,13 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
ord = -ord - 1;
} else {
if (ord >= scoreSums.length) {
int begin = scoreSums.length;
scoreSums = ArrayUtil.grow(scoreSums);
if (scoreMode == ScoreMode.Min) {
Arrays.fill(scoreSums, begin, scoreSums.length, Float.POSITIVE_INFINITY);
} else if (scoreMode == ScoreMode.Max) {
Arrays.fill(scoreSums, begin, scoreSums.length, Float.NEGATIVE_INFINITY);
}
}
}
@ -111,10 +123,16 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
case Total:
scoreSums[ord] = scoreSums[ord] + current;
break;
case Min:
if (current < existing) {
scoreSums[ord] = current;
}
break;
case Max:
if (current > existing) {
scoreSums[ord] = current;
}
break;
}
}
}
@ -187,7 +205,13 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
termID = -termID - 1;
} else {
if (termID >= scoreSums.length) {
int begin = scoreSums.length;
scoreSums = ArrayUtil.grow(scoreSums);
if (scoreMode == ScoreMode.Min) {
Arrays.fill(scoreSums, begin, scoreSums.length, Float.POSITIVE_INFINITY);
} else if (scoreMode == ScoreMode.Max) {
Arrays.fill(scoreSums, begin, scoreSums.length, Float.NEGATIVE_INFINITY);
}
}
}
@ -195,8 +219,12 @@ abstract class TermsWithScoreCollector extends SimpleCollector {
case Total:
scoreSums[termID] += scorer.score();
break;
case Min:
scoreSums[termID] = Math.min(scoreSums[termID], scorer.score());
break;
case Max:
scoreSums[termID] = Math.max(scoreSums[termID], scorer.score());
break;
}
}
}

View File

@ -17,12 +17,6 @@ package org.apache.lucene.search.join;
* limitations under the License.
*/
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.Locale;
import java.util.Set;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
@ -40,6 +34,12 @@ import org.apache.lucene.util.BitDocIdSet;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.Locale;
import java.util.Set;
/**
* This query requires that you index
* children and parent docs as a single block, using the
@ -297,6 +297,7 @@ public class ToParentBlockJoinQuery extends Query {
float totalScore = 0;
float maxScore = Float.NEGATIVE_INFINITY;
float minScore = Float.POSITIVE_INFINITY;
childDocUpto = 0;
parentFreq = 0;
@ -320,6 +321,7 @@ public class ToParentBlockJoinQuery extends Query {
pendingChildScores[childDocUpto] = childScore;
}
maxScore = Math.max(childScore, maxScore);
minScore = Math.min(childFreq, minScore);
totalScore += childScore;
parentFreq += childFreq;
}
@ -340,6 +342,9 @@ public class ToParentBlockJoinQuery extends Query {
case Max:
parentScore = maxScore;
break;
case Min:
parentScore = minScore;
break;
case Total:
parentScore = totalScore;
break;

View File

@ -17,13 +17,6 @@ package org.apache.lucene.search.join;
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
@ -32,7 +25,6 @@ import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
@ -40,31 +32,13 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.LogDocMergePolicy;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.StoredDocument;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.*;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.FilteredQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.NumericRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryUtils;
import org.apache.lucene.search.QueryWrapperFilter;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;
@ -76,6 +50,13 @@ import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.TestUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
public class TestBlockJoin extends LuceneTestCase {
// One resume...
@ -401,7 +382,7 @@ public class TestBlockJoin extends LuceneTestCase {
final List<LeafReaderContext> leaves = reader.leaves();
final int subIndex = ReaderUtil.subIndex(childDocID, leaves);
final LeafReaderContext leaf = leaves.get(subIndex);
final BitSet bits = (BitSet) parents.getDocIdSet(leaf).bits();
final BitSet bits = parents.getDocIdSet(leaf).bits();
return leaf.reader().document(bits.nextSetBit(childDocID - leaf.docBase));
}
@ -722,18 +703,8 @@ public class TestBlockJoin extends LuceneTestCase {
random().nextBoolean() ? BooleanClause.Occur.MUST : BooleanClause.Occur.MUST_NOT);
}
final int x = random().nextInt(4);
final ScoreMode agg;
if (x == 0) {
agg = ScoreMode.None;
} else if (x == 1) {
agg = ScoreMode.Max;
} else if (x == 2) {
agg = ScoreMode.Total;
} else {
agg = ScoreMode.Avg;
}
final ScoreMode agg = ScoreMode.values()[random().nextInt(ScoreMode.values().length)];
final ToParentBlockJoinQuery childJoinQuery = new ToParentBlockJoinQuery(childQuery, parentsFilter, agg);
// To run against the block-join index:

View File

@ -22,8 +22,10 @@ import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DirectoryReader;
@ -34,6 +36,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
@ -46,6 +49,8 @@ import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FieldValueQuery;
import org.apache.lucene.search.FilterScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiCollector;
@ -56,6 +61,7 @@ import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
@ -332,6 +338,127 @@ public class TestJoinUtil extends LuceneTestCase {
context.close();
}
public void testMinMaxScore() throws Exception {
String priceField = "price";
// FunctionQuery would be helpful, but join module doesn't depend on queries module.
Query priceQuery = new Query() {
private final Query fieldQuery = new FieldValueQuery(priceField);
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
Weight fieldWeight = fieldQuery.createWeight(searcher, false);
return new Weight(this) {
@Override
public void extractTerms(Set<Term> terms) {
}
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
return null;
}
@Override
public float getValueForNormalization() throws IOException {
return 0;
}
@Override
public void normalize(float norm, float topLevelBoost) {
}
@Override
public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException {
Scorer fieldScorer = fieldWeight.scorer(context, acceptDocs);
NumericDocValues price = context.reader().getNumericDocValues(priceField);
return new FilterScorer(fieldScorer, this) {
@Override
public float score() throws IOException {
return (float) price.get(in.docID());
}
};
}
};
}
@Override
public String toString(String field) {
return fieldQuery.toString(field);
}
};
Directory dir = newDirectory();
RandomIndexWriter iw = new RandomIndexWriter(
random(),
dir,
newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false))
);
Map<String, Float> lowestScoresPerParent = new HashMap<>();
Map<String, Float> highestScoresPerParent = new HashMap<>();
int numParents = RandomInts.randomIntBetween(random(), 16, 64);
for (int p = 0; p < numParents; p++) {
String parentId = Integer.toString(p);
Document parentDoc = new Document();
parentDoc.add(new StringField("id", parentId, Field.Store.YES));
parentDoc.add(new StringField("type", "to", Field.Store.NO));
parentDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId)));
iw.addDocument(parentDoc);
int numChildren = RandomInts.randomIntBetween(random(), 2, 16);
int lowest = Integer.MAX_VALUE;
int highest = Integer.MIN_VALUE;
for (int c = 0; c < numChildren; c++) {
String childId = Integer.toString(p + c);
Document childDoc = new Document();
childDoc.add(new StringField("id", childId, Field.Store.YES));
parentDoc.add(new StringField("type", "from", Field.Store.NO));
childDoc.add(new SortedDocValuesField("join_field", new BytesRef(parentId)));
int price = random().nextInt(1000);
childDoc.add(new NumericDocValuesField(priceField, price));
iw.addDocument(childDoc);
lowest = Math.min(lowest, price);
highest = Math.max(highest, price);
}
lowestScoresPerParent.put(parentId, (float) lowest);
highestScoresPerParent.put(parentId, (float) highest);
}
iw.close();
IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir));
SortedDocValues[] values = new SortedDocValues[searcher.getIndexReader().leaves().size()];
for (LeafReaderContext leadContext : searcher.getIndexReader().leaves()) {
values[leadContext.ord] = DocValues.getSorted(leadContext.reader(), "join_field");
}
MultiDocValues.OrdinalMap ordinalMap = MultiDocValues.OrdinalMap.build(
searcher.getIndexReader().getCoreCacheKey(), values, PackedInts.DEFAULT
);
BooleanQuery fromQuery = new BooleanQuery();
fromQuery.add(priceQuery, BooleanClause.Occur.MUST);
Query toQuery = new TermQuery(new Term("type", "to"));
Query joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, searcher, ScoreMode.Min, ordinalMap);
TopDocs topDocs = searcher.search(joinQuery, numParents);
assertEquals(numParents, topDocs.totalHits);
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
ScoreDoc scoreDoc = topDocs.scoreDocs[i];
String id = SlowCompositeReaderWrapper.wrap(searcher.getIndexReader()).document(scoreDoc.doc).get("id");
assertEquals(lowestScoresPerParent.get(id), scoreDoc.score, 0f);
}
joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, searcher, ScoreMode.Max, ordinalMap);
topDocs = searcher.search(joinQuery, numParents);
assertEquals(numParents, topDocs.totalHits);
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
ScoreDoc scoreDoc = topDocs.scoreDocs[i];
String id = SlowCompositeReaderWrapper.wrap(searcher.getIndexReader()).document(scoreDoc.doc).get("id");
assertEquals(highestScoresPerParent.get(id), scoreDoc.score, 0f);
}
searcher.getIndexReader().close();
dir.close();
}
// TermsWithScoreCollector.MV.Avg forgets to grow beyond TermsWithScoreCollector.INITIAL_ARRAY_SIZE
public void testOverflowTermsWithScoreCollector() throws Exception {
test300spartans(true, ScoreMode.Avg);
@ -382,7 +509,6 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(1, result.totalHits);
assertEquals(0, result.scoreDocs[0].doc);
indexSearcher.getIndexReader().close();
dir.close();
}
@ -1073,26 +1199,32 @@ public class TestJoinUtil extends LuceneTestCase {
private static class JoinScore {
float maxScore;
float minScore = Float.POSITIVE_INFINITY;
float maxScore = Float.NEGATIVE_INFINITY;
float total;
int count;
void addScore(float score) {
total += score;
if (score > maxScore) {
maxScore = score;
}
if (score < minScore) {
minScore = score;
}
total += score;
count++;
}
float score(ScoreMode mode) {
switch (mode) {
case None:
return 1.0f;
return 1f;
case Total:
return total;
case Avg:
return total / count;
case Min:
return minScore;
case Max:
return maxScore;
}