Adding "min" score mode to parent-child queries

Support for "max", "sum", and "avg" already existed.
This commit is contained in:
Chris Earle 2014-09-17 16:33:05 -04:00
parent 5dc5849922
commit 9b84ad3c7b
8 changed files with 303 additions and 28 deletions

View File

@ -31,7 +31,7 @@ query the `total_hits` is always correct.
==== Scoring capabilities ==== Scoring capabilities
The `has_child` also has scoring support. The The `has_child` also has scoring support. The
supported score types are `max`, `sum`, `avg` or `none`. The default is supported score types are `min`, `max`, `sum`, `avg` or `none`. The default is
`none` and yields the same behaviour as in previous versions. If the `none` and yields the same behaviour as in previous versions. If the
score type is set to another value than `none`, the scores of all the score type is set to another value than `none`, the scores of all the
matching child documents are aggregated into the associated parent matching child documents are aggregated into the associated parent
@ -98,5 +98,3 @@ APIS, eg:
-------------------------------------------------- --------------------------------------------------
curl -XGET "http://localhost:9200/_stats/id_cache?pretty&human" curl -XGET "http://localhost:9200/_stats/id_cache?pretty&human"
-------------------------------------------------- --------------------------------------------------

View File

@ -176,6 +176,9 @@ public class ChildrenQuery extends Query {
try { try {
if (minChildren == 0 && maxChildren == 0 && scoreType != ScoreType.NONE) { if (minChildren == 0 && maxChildren == 0 && scoreType != ScoreType.NONE) {
switch (scoreType) { switch (scoreType) {
case MIN:
collector = new MinCollector(globalIfd, sc, parentType);
break;
case MAX: case MAX:
collector = new MaxCollector(globalIfd, sc, parentType); collector = new MaxCollector(globalIfd, sc, parentType);
break; break;
@ -186,6 +189,9 @@ public class ChildrenQuery extends Query {
} }
if (collector == null) { if (collector == null) {
switch (scoreType) { switch (scoreType) {
case MIN:
collector = new MinCountCollector(globalIfd, sc, parentType);
break;
case MAX: case MAX:
collector = new MaxCountCollector(globalIfd, sc, parentType); collector = new MaxCountCollector(globalIfd, sc, parentType);
break; break;
@ -468,6 +474,21 @@ public class ChildrenQuery extends Query {
} }
} }
private final static class MinCollector extends ParentScoreCollector {
private MinCollector(IndexParentChildFieldData globalIfd, SearchContext searchContext, String parentType) {
super(globalIfd, searchContext, parentType);
}
@Override
protected void existingParent(long parentIdx) throws IOException {
float currentScore = scorer.score();
if (currentScore < scores.get(parentIdx)) {
scores.set(parentIdx, currentScore);
}
}
}
private final static class MaxCountCollector extends ParentScoreCountCollector { private final static class MaxCountCollector extends ParentScoreCountCollector {
private MaxCountCollector(IndexParentChildFieldData globalIfd, SearchContext searchContext, String parentType) { private MaxCountCollector(IndexParentChildFieldData globalIfd, SearchContext searchContext, String parentType) {
@ -484,6 +505,22 @@ public class ChildrenQuery extends Query {
} }
} }
private final static class MinCountCollector extends ParentScoreCountCollector {
private MinCountCollector(IndexParentChildFieldData globalIfd, SearchContext searchContext, String parentType) {
super(globalIfd, searchContext, parentType);
}
@Override
protected void existingParent(long parentIdx) throws IOException {
float currentScore = scorer.score();
if (currentScore < scores.get(parentIdx)) {
scores.set(parentIdx, currentScore);
}
occurrences.increment(parentIdx, 1);
}
}
private final static class SumCountAndAvgCollector extends ParentScoreCountCollector { private final static class SumCountAndAvgCollector extends ParentScoreCountCollector {
SumCountAndAvgCollector(IndexParentChildFieldData globalIfd, SearchContext searchContext, String parentType) { SumCountAndAvgCollector(IndexParentChildFieldData globalIfd, SearchContext searchContext, String parentType) {

View File

@ -24,6 +24,11 @@ import org.elasticsearch.ElasticsearchIllegalArgumentException;
* Defines how scores from child documents are mapped into the parent document. * Defines how scores from child documents are mapped into the parent document.
*/ */
public enum ScoreType { public enum ScoreType {
/**
* Only the lowest score of all matching child documents is mapped into the
* parent.
*/
MIN,
/** /**
* Only the highest score of all matching child documents is mapped into the * Only the highest score of all matching child documents is mapped into the
* parent. * parent.
@ -50,6 +55,8 @@ public enum ScoreType {
public static ScoreType fromString(String type) { public static ScoreType fromString(String type) {
if ("none".equals(type)) { if ("none".equals(type)) {
return NONE; return NONE;
} else if ("min".equals(type)) {
return MIN;
} else if ("max".equals(type)) { } else if ("max".equals(type)) {
return MAX; return MAX;
} else if ("avg".equals(type)) { } else if ("avg".equals(type)) {

View File

@ -218,11 +218,15 @@ public class TopChildrenQuery extends Query {
parentDoc.docId = parentDocId; parentDoc.docId = parentDocId;
parentDoc.count = 1; parentDoc.count = 1;
parentDoc.maxScore = scoreDoc.score; parentDoc.maxScore = scoreDoc.score;
parentDoc.minScore = scoreDoc.score;
parentDoc.sumScores = scoreDoc.score; parentDoc.sumScores = scoreDoc.score;
readerParentDocs.put(parentDocId, parentDoc); readerParentDocs.put(parentDocId, parentDoc);
} else { } else {
parentDoc.count++; parentDoc.count++;
parentDoc.sumScores += scoreDoc.score; parentDoc.sumScores += scoreDoc.score;
if (scoreDoc.score < parentDoc.minScore) {
parentDoc.minScore = scoreDoc.score;
}
if (scoreDoc.score > parentDoc.maxScore) { if (scoreDoc.score > parentDoc.maxScore) {
parentDoc.maxScore = scoreDoc.score; parentDoc.maxScore = scoreDoc.score;
} }
@ -320,11 +324,19 @@ public class TopChildrenQuery extends Query {
public Scorer scorer(AtomicReaderContext context, Bits acceptDocs) throws IOException { public Scorer scorer(AtomicReaderContext context, Bits acceptDocs) throws IOException {
ParentDoc[] readerParentDocs = parentDocs.get(context.reader().getCoreCacheKey()); ParentDoc[] readerParentDocs = parentDocs.get(context.reader().getCoreCacheKey());
if (readerParentDocs != null) { if (readerParentDocs != null) {
if (scoreType == ScoreType.MAX) { if (scoreType == ScoreType.MIN) {
return new ParentScorer(this, readerParentDocs) { return new ParentScorer(this, readerParentDocs) {
@Override @Override
public float score() throws IOException { public float score() throws IOException {
assert doc.docId >= 0 || doc.docId < NO_MORE_DOCS; assert doc.docId >= 0 && doc.docId != NO_MORE_DOCS;
return doc.minScore;
}
};
} else if (scoreType == ScoreType.MAX) {
return new ParentScorer(this, readerParentDocs) {
@Override
public float score() throws IOException {
assert doc.docId >= 0 && doc.docId != NO_MORE_DOCS;
return doc.maxScore; return doc.maxScore;
} }
}; };
@ -332,7 +344,7 @@ public class TopChildrenQuery extends Query {
return new ParentScorer(this, readerParentDocs) { return new ParentScorer(this, readerParentDocs) {
@Override @Override
public float score() throws IOException { public float score() throws IOException {
assert doc.docId >= 0 || doc.docId < NO_MORE_DOCS; assert doc.docId >= 0 && doc.docId != NO_MORE_DOCS;
return doc.sumScores / doc.count; return doc.sumScores / doc.count;
} }
}; };
@ -340,7 +352,7 @@ public class TopChildrenQuery extends Query {
return new ParentScorer(this, readerParentDocs) { return new ParentScorer(this, readerParentDocs) {
@Override @Override
public float score() throws IOException { public float score() throws IOException {
assert doc.docId >= 0 || doc.docId < NO_MORE_DOCS; assert doc.docId >= 0 && doc.docId != NO_MORE_DOCS;
return doc.sumScores; return doc.sumScores;
} }
@ -412,6 +424,7 @@ public class TopChildrenQuery extends Query {
private static class ParentDoc { private static class ParentDoc {
public int docId; public int docId;
public int count; public int count;
public float minScore = Float.NaN;
public float maxScore = Float.NaN; public float maxScore = Float.NaN;
public float sumScores = 0; public float sumScores = 0;
} }

View File

@ -42,10 +42,17 @@ import static org.hamcrest.Matchers.equalTo;
@LuceneTestCase.SuppressCodecs(value = {"Lucene40", "Lucene3x"}) @LuceneTestCase.SuppressCodecs(value = {"Lucene40", "Lucene3x"})
public abstract class AbstractChildTests extends ElasticsearchSingleNodeLuceneTestCase { public abstract class AbstractChildTests extends ElasticsearchSingleNodeLuceneTestCase {
/**
* The name of the field within the child type that stores a score to use in test queries.
* <p />
* Its type is {@code double}.
*/
protected static String CHILD_SCORE_NAME = "childScore";
static SearchContext createSearchContext(String indexName, String parentType, String childType) throws IOException { static SearchContext createSearchContext(String indexName, String parentType, String childType) throws IOException {
IndexService indexService = createIndex(indexName); IndexService indexService = createIndex(indexName);
MapperService mapperService = indexService.mapperService(); MapperService mapperService = indexService.mapperService();
mapperService.merge(childType, new CompressedString(PutMappingRequest.buildFromSimplifiedDef(childType, "_parent", "type=" + parentType).string()), true); mapperService.merge(childType, new CompressedString(PutMappingRequest.buildFromSimplifiedDef(childType, "_parent", "type=" + parentType, CHILD_SCORE_NAME, "type=double").string()), true);
return createSearchContext(indexService); return createSearchContext(indexService);
} }

View File

@ -24,6 +24,7 @@ import com.carrotsearch.hppc.ObjectObjectOpenHashMap;
import com.carrotsearch.randomizedtesting.generators.RandomInts; import com.carrotsearch.randomizedtesting.generators.RandomInts;
import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoubleField;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.index.*; import org.apache.lucene.index.*;
@ -34,11 +35,17 @@ import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.search.NotFilter; import org.elasticsearch.common.lucene.search.NotFilter;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.common.lucene.search.XFilteredQuery; import org.elasticsearch.common.lucene.search.XFilteredQuery;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.cache.fixedbitset.FixedBitSetFilter; import org.elasticsearch.index.cache.fixedbitset.FixedBitSetFilter;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.index.fielddata.plain.ParentChildIndexFieldData; import org.elasticsearch.index.fielddata.plain.ParentChildIndexFieldData;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.Uid; import org.elasticsearch.index.mapper.Uid;
import org.elasticsearch.index.mapper.internal.IdFieldMapper;
import org.elasticsearch.index.mapper.internal.ParentFieldMapper; import org.elasticsearch.index.mapper.internal.ParentFieldMapper;
import org.elasticsearch.index.mapper.internal.TypeFieldMapper; import org.elasticsearch.index.mapper.internal.TypeFieldMapper;
import org.elasticsearch.index.mapper.internal.UidFieldMapper; import org.elasticsearch.index.mapper.internal.UidFieldMapper;
@ -56,6 +63,9 @@ import java.util.NavigableMap;
import java.util.Random; import java.util.Random;
import java.util.TreeMap; import java.util.TreeMap;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class ChildrenQueryTests extends AbstractChildTests { public class ChildrenQueryTests extends AbstractChildTests {
@BeforeClass @BeforeClass
@ -128,7 +138,7 @@ public class ChildrenQueryTests extends AbstractChildTests {
String childValue = childValues[random().nextInt(childValues.length)]; String childValue = childValues[random().nextInt(childValues.length)];
document = new Document(); document = new Document();
document.add(new StringField(UidFieldMapper.NAME, Uid.createUid("child", Integer.toString(childDocId)), Field.Store.NO)); document.add(new StringField(UidFieldMapper.NAME, Uid.createUid("child", Integer.toString(childDocId++)), Field.Store.NO));
document.add(new StringField(TypeFieldMapper.NAME, "child", Field.Store.NO)); document.add(new StringField(TypeFieldMapper.NAME, "child", Field.Store.NO));
document.add(new StringField(ParentFieldMapper.NAME, Uid.createUid("parent", parent), Field.Store.NO)); document.add(new StringField(ParentFieldMapper.NAME, Uid.createUid("parent", parent), Field.Store.NO));
document.add(new StringField("field1", childValue, Field.Store.NO)); document.add(new StringField("field1", childValue, Field.Store.NO));
@ -264,4 +274,140 @@ public class ChildrenQueryTests extends AbstractChildTests {
directory.close(); directory.close();
} }
@Test
public void testMinScoreMode() throws IOException {
assertScoreType(ScoreType.MIN);
}
@Test
public void testMaxScoreMode() throws IOException {
assertScoreType(ScoreType.MAX);
}
@Test
public void testAvgScoreMode() throws IOException {
assertScoreType(ScoreType.AVG);
}
@Test
public void testSumScoreMode() throws IOException {
assertScoreType(ScoreType.SUM);
}
/**
* Assert that the {@code scoreType} operates as expected and parents are found in the expected order.
* <p />
* This will use the test index's parent/child types to create parents with multiple children. Each child will have
* a randomly generated scored stored in {@link #CHILD_SCORE_NAME}, which is used to score based on the
* {@code scoreType} by using a {@link MockScorer} to determine the expected scores.
* @param scoreType The score type to use within the query to score parents relative to their children.
* @throws IOException if any unexpected error occurs
*/
private void assertScoreType(ScoreType scoreType) throws IOException {
SearchContext context = SearchContext.current();
Directory directory = newDirectory();
IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
// calculates the expected score per parent
MockScorer scorer = new MockScorer(scoreType);
scorer.scores = new FloatArrayList(10);
// number of parents to generate
int parentDocs = scaledRandomIntBetween(2, 10);
// unique child ID
int childDocId = 0;
// Parent ID to expected score
Map<String, Float> parentScores = new TreeMap<>();
// Add a few random parents to ensure that the children's score is appropriately taken into account
for (int parentDocId = 0; parentDocId < parentDocs; ++parentDocId) {
String parent = Integer.toString(parentDocId);
// Create the parent
Document parentDocument = new Document();
parentDocument.add(new StringField(UidFieldMapper.NAME, Uid.createUid("parent", parent), Field.Store.YES));
parentDocument.add(new StringField(IdFieldMapper.NAME, parent, Field.Store.YES));
parentDocument.add(new StringField(TypeFieldMapper.NAME, "parent", Field.Store.NO));
// add the parent to the index
writer.addDocument(parentDocument);
int numChildDocs = scaledRandomIntBetween(1, 10);
// forget any parent's previous scores
scorer.scores.clear();
// associate children with the parent
for (int i = 0; i < numChildDocs; ++i) {
int childScore = random().nextInt(128);
Document childDocument = new Document();
childDocument.add(new StringField(UidFieldMapper.NAME, Uid.createUid("child", Integer.toString(childDocId++)), Field.Store.NO));
childDocument.add(new StringField(TypeFieldMapper.NAME, "child", Field.Store.NO));
// parent association:
childDocument.add(new StringField(ParentFieldMapper.NAME, Uid.createUid("parent", parent), Field.Store.NO));
childDocument.add(new DoubleField(CHILD_SCORE_NAME, childScore, Field.Store.NO));
// remember the score to be calculated
scorer.scores.add(childScore);
// add the associated child to the index
writer.addDocument(childDocument);
}
// this score that should be returned for this parent
parentScores.put(parent, scorer.score());
}
writer.commit();
IndexReader reader = DirectoryReader.open(writer, true);
IndexSearcher searcher = new IndexSearcher(reader);
// setup to read the parent/child map
Engine.SimpleSearcher engineSearcher = new Engine.SimpleSearcher(ChildrenQueryTests.class.getSimpleName(), searcher);
((TestSearchContext)context).setSearcher(new ContextIndexSearcher(context, engineSearcher));
ParentFieldMapper parentFieldMapper = context.mapperService().documentMapper("child").parentFieldMapper();
ParentChildIndexFieldData parentChildIndexFieldData = context.fieldData().getForField(parentFieldMapper);
FixedBitSetFilter parentFilter = wrap(new TermFilter(new Term(TypeFieldMapper.NAME, "parent")));
// child query that returns the score as the value of "childScore" for each child document,
// with the parent's score determined by the score type
FieldMapper fieldMapper = context.mapperService().smartNameFieldMapper(CHILD_SCORE_NAME);
IndexNumericFieldData fieldData = context.fieldData().getForField(fieldMapper);
FieldValueFactorFunction fieldScore = new FieldValueFactorFunction(CHILD_SCORE_NAME, 1, FieldValueFactorFunction.Modifier.NONE, fieldData);
Query childQuery = new FunctionScoreQuery(new FilteredQuery(Queries.newMatchAllQuery(), new TermFilter(new Term(TypeFieldMapper.NAME, "child"))), fieldScore);
// Perform the search for the documents using the selected score type
TopDocs docs =
searcher.search(
new ChildrenQuery(parentChildIndexFieldData, "parent", "child", parentFilter, childQuery, scoreType, 0, 0, parentDocs, null),
parentDocs);
assertThat("Expected all parents", docs.totalHits, is(parentDocs));
// score should be descending (just a sanity check)
float topScore = docs.scoreDocs[0].score;
// ensure each score is returned as expected
for (int i = 0; i < parentDocs; ++i) {
ScoreDoc scoreDoc = docs.scoreDocs[i];
// get the ID from the document to get its expected score; remove it so we cannot double-count it
float score = parentScores.remove(reader.document(scoreDoc.doc).get(IdFieldMapper.NAME));
// expect exact match
assertThat("Unexpected score", scoreDoc.score, is(score));
assertThat("Not descending", score, lessThanOrEqualTo(topScore));
// it had better keep descending
topScore = score;
}
reader.close();
writer.close();
directory.close();
}
} }

View File

@ -39,9 +39,20 @@ class MockScorer extends Scorer {
return 1.0f; return 1.0f;
} }
float aggregateScore = 0; float aggregateScore = 0;
for (int i = 0; i < scores.elementsCount; i++) {
// in the case of a min value, it can't start at 0 (the lowest score); in all cases, it doesn't hurt to use the
// first score, so we can safely use the first value by skipping it in the loop
if (scores.elementsCount != 0) {
aggregateScore = scores.buffer[0];
for (int i = 1; i < scores.elementsCount; i++) {
float score = scores.buffer[i]; float score = scores.buffer[i];
switch (scoreType) { switch (scoreType) {
case MIN:
if (aggregateScore > score) {
aggregateScore = score;
}
break;
case MAX: case MAX:
if (aggregateScore < score) { if (aggregateScore < score) {
aggregateScore = score; aggregateScore = score;
@ -57,6 +68,7 @@ class MockScorer extends Scorer {
if (scoreType == ScoreType.AVG) { if (scoreType == ScoreType.AVG) {
aggregateScore /= scores.elementsCount; aggregateScore /= scores.elementsCount;
} }
}
return aggregateScore; return aggregateScore;
} }

View File

@ -0,0 +1,55 @@
package org.elasticsearch.index.search.child;
import org.elasticsearch.ElasticsearchIllegalArgumentException;
import org.junit.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
/**
* Tests {@link ScoreType} to ensure backward compatibility of any changes.
*/
public class ScoreTypeTests {
@Test
public void minFromString() {
assertThat("fromString(min) != MIN", ScoreType.MIN, equalTo(ScoreType.fromString("min")));
}
@Test
public void maxFromString() {
assertThat("fromString(max) != MAX", ScoreType.MAX, equalTo(ScoreType.fromString("max")));
}
@Test
public void avgFromString() {
assertThat("fromString(avg) != AVG", ScoreType.AVG, equalTo(ScoreType.fromString("avg")));
}
@Test
public void sumFromString() {
assertThat("fromString(sum) != SUM", ScoreType.SUM, equalTo(ScoreType.fromString("sum")));
// allowed for consistency with ScoreMode.Total:
assertThat("fromString(total) != SUM", ScoreType.SUM, equalTo(ScoreType.fromString("total")));
}
@Test
public void noneFromString() {
assertThat("fromString(none) != NONE", ScoreType.NONE, equalTo(ScoreType.fromString("none")));
}
/**
* Should throw {@link ElasticsearchIllegalArgumentException} instead of NPE.
*/
@Test(expected = ElasticsearchIllegalArgumentException.class)
public void nullFromString_throwsException() {
ScoreType.fromString(null);
}
/**
* Failure should not change (and the value should never match anything...).
*/
@Test(expected = ElasticsearchIllegalArgumentException.class)
public void unrecognizedFromString_throwsException() {
ScoreType.fromString("unrecognized value");
}
}