Query DSL: custom_filters_score - add score_mode to control how the scores of matching filters are combined, closes #1205.

This commit is contained in:
Shay Banon 2011-08-04 03:31:14 +03:00
parent 4a886dbae1
commit d93bc02309
4 changed files with 168 additions and 21 deletions

View File

@ -34,6 +34,7 @@ import org.elasticsearch.common.lucene.docset.DocSet;
import org.elasticsearch.common.lucene.docset.DocSets; import org.elasticsearch.common.lucene.docset.DocSets;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Set; import java.util.Set;
@ -73,12 +74,16 @@ public class FiltersFunctionScoreQuery extends Query {
} }
} }
/**
 * Controls how the scores of the filter functions whose filters match a document are combined:
 * First uses only the first matching filter's function, Avg averages the matching functions'
 * scores, Max takes the highest of them, and Total sums them.
 */
public static enum ScoreMode {First, Avg, Max, Total}
Query subQuery; Query subQuery;
final FilterFunction[] filterFunctions; final FilterFunction[] filterFunctions;
final ScoreMode scoreMode;
DocSet[] docSets; DocSet[] docSets;
public FiltersFunctionScoreQuery(Query subQuery, FilterFunction[] filterFunctions) { public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions) {
this.subQuery = subQuery; this.subQuery = subQuery;
this.scoreMode = scoreMode;
this.filterFunctions = filterFunctions; this.filterFunctions = filterFunctions;
this.docSets = new DocSet[filterFunctions.length]; this.docSets = new DocSet[filterFunctions.length];
} }
@ -151,7 +156,7 @@ public class FiltersFunctionScoreQuery extends Query {
filterFunction.function.setNextReader(reader); filterFunction.function.setNextReader(reader);
docSets[i] = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); docSets[i] = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
} }
return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filterFunctions, docSets); return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, scoreMode, filterFunctions, docSets);
} }
@Override @Override
@ -161,6 +166,7 @@ public class FiltersFunctionScoreQuery extends Query {
return subQueryExpl; return subQueryExpl;
} }
if (scoreMode == ScoreMode.First) {
for (FilterFunction filterFunction : filterFunctions) { for (FilterFunction filterFunction : filterFunctions) {
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)); DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
if (docSet.get(doc)) { if (docSet.get(doc)) {
@ -174,6 +180,48 @@ public class FiltersFunctionScoreQuery extends Query {
return res; return res;
} }
} }
} else {
int count = 0;
float total = 0;
float max = Float.NEGATIVE_INFINITY;
ArrayList<Explanation> filtersExplanations = new ArrayList<Explanation>();
for (FilterFunction filterFunction : filterFunctions) {
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
if (docSet.get(doc)) {
filterFunction.function.setNextReader(reader);
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
float sc = functionExplanation.getValue();
count++;
total += sc;
max = Math.max(sc, max);
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
res.addDetail(functionExplanation);
res.addDetail(new Explanation(getValue(), "queryBoost"));
filtersExplanations.add(res);
}
}
if (count > 0) {
float sc = 0;
switch (scoreMode) {
case Avg:
sc = total / count;
break;
case Max:
sc = max;
break;
case Total:
sc = total;
break;
}
sc *= getValue();
Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase() + "]");
for (Explanation explanation : filtersExplanations) {
res.addDetail(explanation);
}
return res;
}
}
float sc = getValue() * subQueryExpl.getValue(); float sc = getValue() * subQueryExpl.getValue();
Explanation res = new ComplexExplanation(true, sc, "custom score, no filter match, product of:"); Explanation res = new ComplexExplanation(true, sc, "custom score, no filter match, product of:");
@ -188,13 +236,15 @@ public class FiltersFunctionScoreQuery extends Query {
private final float subQueryWeight; private final float subQueryWeight;
private final Scorer scorer; private final Scorer scorer;
private final FilterFunction[] filterFunctions; private final FilterFunction[] filterFunctions;
private final ScoreMode scoreMode;
private final DocSet[] docSets; private final DocSet[] docSets;
private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer, private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer,
FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException { ScoreMode scoreMode, FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException {
super(similarity); super(similarity);
this.subQueryWeight = w.getValue(); this.subQueryWeight = w.getValue();
this.scorer = scorer; this.scorer = scorer;
this.scoreMode = scoreMode;
this.filterFunctions = filterFunctions; this.filterFunctions = filterFunctions;
this.docSets = docSets; this.docSets = docSets;
} }
@ -218,11 +268,38 @@ public class FiltersFunctionScoreQuery extends Query {
public float score() throws IOException { public float score() throws IOException {
int docId = scorer.docID(); int docId = scorer.docID();
float score = scorer.score(); float score = scorer.score();
if (scoreMode == ScoreMode.First) {
for (int i = 0; i < filterFunctions.length; i++) { for (int i = 0; i < filterFunctions.length; i++) {
if (docSets[i].get(docId)) { if (docSets[i].get(docId)) {
return subQueryWeight * filterFunctions[i].function.score(docId, score); return subQueryWeight * filterFunctions[i].function.score(docId, score);
} }
} }
} else if (scoreMode == ScoreMode.Max) {
float maxScore = Float.NEGATIVE_INFINITY;
for (int i = 0; i < filterFunctions.length; i++) {
if (docSets[i].get(docId)) {
maxScore = Math.max(filterFunctions[i].function.score(docId, score), maxScore);
}
}
if (maxScore != Float.NEGATIVE_INFINITY) {
score = maxScore;
}
} else { // Avg / Total
float totalScore = 0.0f;
int count = 0;
for (int i = 0; i < filterFunctions.length; i++) {
if (docSets[i].get(docId)) {
totalScore += filterFunctions[i].function.score(docId, score);
count++;
}
}
if (count != 0) {
score = totalScore;
if (scoreMode == ScoreMode.Avg) {
score /= count;
}
}
}
return subQueryWeight * score; return subQueryWeight * score;
} }
} }

View File

@ -42,6 +42,8 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
private Map<String, Object> params = null; private Map<String, Object> params = null;
private String scoreMode;
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>(); private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
private ArrayList<String> scripts = new ArrayList<String>(); private ArrayList<String> scripts = new ArrayList<String>();
private TFloatArrayList boosts = new TFloatArrayList(); private TFloatArrayList boosts = new TFloatArrayList();
@ -64,6 +66,11 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
return this; return this;
} }
/**
 * Sets the score mode, written out as the "score_mode" field when it is not null.
 * The custom_filters_score parser in this change accepts "first", "avg", "max" and "total",
 * and falls back to "first" when the field is absent.
 */
public CustomFiltersScoreQueryBuilder scoreMode(String scoreMode) {
this.scoreMode = scoreMode;
return this;
}
/** /**
* Sets the language of the script. * Sets the language of the script.
*/ */
@ -124,6 +131,10 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
} }
builder.endArray(); builder.endArray();
if (scoreMode != null) {
builder.field("score_mode", scoreMode);
}
if (lang != null) { if (lang != null) {
builder.field("lang", lang); builder.field("lang", lang);
} }

View File

@ -58,6 +58,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
String scriptLang = null; String scriptLang = null;
Map<String, Object> vars = null; Map<String, Object> vars = null;
FiltersFunctionScoreQuery.ScoreMode scoreMode = FiltersFunctionScoreQuery.ScoreMode.First;
ArrayList<Filter> filters = new ArrayList<Filter>(); ArrayList<Filter> filters = new ArrayList<Filter>();
ArrayList<String> scripts = new ArrayList<String>(); ArrayList<String> scripts = new ArrayList<String>();
TFloatArrayList boosts = new TFloatArrayList(); TFloatArrayList boosts = new TFloatArrayList();
@ -110,6 +111,19 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
scriptLang = parser.text(); scriptLang = parser.text();
} else if ("boost".equals(currentFieldName)) { } else if ("boost".equals(currentFieldName)) {
boost = parser.floatValue(); boost = parser.floatValue();
} else if ("score_mode".equals(currentFieldName) || "scoreMode".equals(currentFieldName)) {
// Map the textual score mode onto the enum; unknown values are rejected rather than silently ignored.
String sScoreMode = parser.text();
if ("avg".equals(sScoreMode)) {
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Avg;
} else if ("max".equals(sScoreMode)) {
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Max;
} else if ("total".equals(sScoreMode)) {
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Total;
} else if ("first".equals(sScoreMode)) {
scoreMode = FiltersFunctionScoreQuery.ScoreMode.First;
} else {
// Name the query type actually being parsed so the error points users at the right clause.
throw new QueryParsingException(parseContext.index(), "illegal score_mode for custom_filters_score query [" + sScoreMode + "]");
}
} }
} }
} }
@ -136,7 +150,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
} }
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction); filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction);
} }
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions); FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode, filterFunctions);
functionScoreQuery.setBoost(boost); functionScoreQuery.setBoost(boost);
return functionScoreQuery; return functionScoreQuery;
} }

View File

@ -148,10 +148,10 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
@Test public void testCustomFiltersScore() throws Exception { @Test public void testCustomFiltersScore() throws Exception {
client.admin().indices().prepareDelete().execute().actionGet(); client.admin().indices().prepareDelete().execute().actionGet();
client.prepareIndex("test", "type", "1").setSource("field", "value1").execute().actionGet(); client.prepareIndex("test", "type", "1").setSource("field", "value1", "color", "red").execute().actionGet();
client.prepareIndex("test", "type", "2").setSource("field", "value2").execute().actionGet(); client.prepareIndex("test", "type", "2").setSource("field", "value2", "color", "blue").execute().actionGet();
client.prepareIndex("test", "type", "3").setSource("field", "value3").execute().actionGet(); client.prepareIndex("test", "type", "3").setSource("field", "value3", "color", "red").execute().actionGet();
client.prepareIndex("test", "type", "4").setSource("field", "value4").execute().actionGet(); client.prepareIndex("test", "type", "4").setSource("field", "value4", "color", "blue").execute().actionGet();
client.admin().indices().prepareRefresh().execute().actionGet(); client.admin().indices().prepareRefresh().execute().actionGet();
@ -194,5 +194,50 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f)); assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3"))); assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f)); assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
searchResponse = client.prepareSearch("test")
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("total")
.add(termFilter("field", "value4"), 2)
.add(termFilter("field", "value1"), 3)
.add(termFilter("color", "red"), 5))
.setExplain(true)
.execute().actionGet();
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
assertThat(searchResponse.hits().getAt(0).score(), equalTo(8.0f));
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
searchResponse = client.prepareSearch("test")
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("max")
.add(termFilter("field", "value4"), 2)
.add(termFilter("field", "value1"), 3)
.add(termFilter("color", "red"), 5))
.setExplain(true)
.execute().actionGet();
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f));
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
searchResponse = client.prepareSearch("test")
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("avg")
.add(termFilter("field", "value4"), 2)
.add(termFilter("field", "value1"), 3)
.add(termFilter("color", "red"), 5))
.setExplain(true)
.execute().actionGet();
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
assertThat(searchResponse.hits().getAt(0).id(), equalTo("3"));
assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f));
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
assertThat(searchResponse.hits().getAt(1).id(), equalTo("1"));
assertThat(searchResponse.hits().getAt(1).score(), equalTo(4.0f));
logger.info("--> Hit[1] {} Explanation {}", searchResponse.hits().getAt(1).id(), searchResponse.hits().getAt(1).explanation());
} }
} }