Query DSL: custom_filters_score - add score_mode to control filters matching scoring, closes #1205.
This commit is contained in:
parent
4a886dbae1
commit
d93bc02309
|
@ -34,6 +34,7 @@ import org.elasticsearch.common.lucene.docset.DocSet;
|
|||
import org.elasticsearch.common.lucene.docset.DocSets;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -73,12 +74,16 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
}
|
||||
}
|
||||
|
||||
public static enum ScoreMode {First, Avg, Max, Total}
|
||||
|
||||
Query subQuery;
|
||||
final FilterFunction[] filterFunctions;
|
||||
final ScoreMode scoreMode;
|
||||
DocSet[] docSets;
|
||||
|
||||
public FiltersFunctionScoreQuery(Query subQuery, FilterFunction[] filterFunctions) {
|
||||
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions) {
|
||||
this.subQuery = subQuery;
|
||||
this.scoreMode = scoreMode;
|
||||
this.filterFunctions = filterFunctions;
|
||||
this.docSets = new DocSet[filterFunctions.length];
|
||||
}
|
||||
|
@ -151,7 +156,7 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
filterFunction.function.setNextReader(reader);
|
||||
docSets[i] = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||
}
|
||||
return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filterFunctions, docSets);
|
||||
return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, scoreMode, filterFunctions, docSets);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -161,16 +166,59 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
return subQueryExpl;
|
||||
}
|
||||
|
||||
for (FilterFunction filterFunction : filterFunctions) {
|
||||
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||
if (docSet.get(doc)) {
|
||||
filterFunction.function.setNextReader(reader);
|
||||
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
|
||||
float sc = getValue() * functionExplanation.getValue();
|
||||
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
|
||||
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
|
||||
res.addDetail(functionExplanation);
|
||||
res.addDetail(new Explanation(getValue(), "queryBoost"));
|
||||
if (scoreMode == ScoreMode.First) {
|
||||
for (FilterFunction filterFunction : filterFunctions) {
|
||||
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||
if (docSet.get(doc)) {
|
||||
filterFunction.function.setNextReader(reader);
|
||||
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
|
||||
float sc = getValue() * functionExplanation.getValue();
|
||||
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
|
||||
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
|
||||
res.addDetail(functionExplanation);
|
||||
res.addDetail(new Explanation(getValue(), "queryBoost"));
|
||||
return res;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int count = 0;
|
||||
float total = 0;
|
||||
float max = Float.NEGATIVE_INFINITY;
|
||||
ArrayList<Explanation> filtersExplanations = new ArrayList<Explanation>();
|
||||
for (FilterFunction filterFunction : filterFunctions) {
|
||||
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||
if (docSet.get(doc)) {
|
||||
filterFunction.function.setNextReader(reader);
|
||||
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
|
||||
float sc = functionExplanation.getValue();
|
||||
count++;
|
||||
total += sc;
|
||||
max = Math.max(sc, max);
|
||||
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
|
||||
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
|
||||
res.addDetail(functionExplanation);
|
||||
res.addDetail(new Explanation(getValue(), "queryBoost"));
|
||||
filtersExplanations.add(res);
|
||||
}
|
||||
}
|
||||
if (count > 0) {
|
||||
float sc = 0;
|
||||
switch (scoreMode) {
|
||||
case Avg:
|
||||
sc = total / count;
|
||||
break;
|
||||
case Max:
|
||||
sc = max;
|
||||
break;
|
||||
case Total:
|
||||
sc = total;
|
||||
break;
|
||||
}
|
||||
sc *= getValue();
|
||||
Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase() + "]");
|
||||
for (Explanation explanation : filtersExplanations) {
|
||||
res.addDetail(explanation);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
@ -188,13 +236,15 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
private final float subQueryWeight;
|
||||
private final Scorer scorer;
|
||||
private final FilterFunction[] filterFunctions;
|
||||
private final ScoreMode scoreMode;
|
||||
private final DocSet[] docSets;
|
||||
|
||||
private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer,
|
||||
FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException {
|
||||
ScoreMode scoreMode, FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException {
|
||||
super(similarity);
|
||||
this.subQueryWeight = w.getValue();
|
||||
this.scorer = scorer;
|
||||
this.scoreMode = scoreMode;
|
||||
this.filterFunctions = filterFunctions;
|
||||
this.docSets = docSets;
|
||||
}
|
||||
|
@ -218,9 +268,36 @@ public class FiltersFunctionScoreQuery extends Query {
|
|||
public float score() throws IOException {
|
||||
int docId = scorer.docID();
|
||||
float score = scorer.score();
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
if (docSets[i].get(docId)) {
|
||||
return subQueryWeight * filterFunctions[i].function.score(docId, score);
|
||||
if (scoreMode == ScoreMode.First) {
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
if (docSets[i].get(docId)) {
|
||||
return subQueryWeight * filterFunctions[i].function.score(docId, score);
|
||||
}
|
||||
}
|
||||
} else if (scoreMode == ScoreMode.Max) {
|
||||
float maxScore = Float.NEGATIVE_INFINITY;
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
if (docSets[i].get(docId)) {
|
||||
maxScore = Math.max(filterFunctions[i].function.score(docId, score), maxScore);
|
||||
}
|
||||
}
|
||||
if (maxScore != Float.NEGATIVE_INFINITY) {
|
||||
score = maxScore;
|
||||
}
|
||||
} else { // Avg / Total
|
||||
float totalScore = 0.0f;
|
||||
int count = 0;
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
if (docSets[i].get(docId)) {
|
||||
totalScore += filterFunctions[i].function.score(docId, score);
|
||||
count++;
|
||||
}
|
||||
}
|
||||
if (count != 0) {
|
||||
score = totalScore;
|
||||
if (scoreMode == ScoreMode.Avg) {
|
||||
score /= count;
|
||||
}
|
||||
}
|
||||
}
|
||||
return subQueryWeight * score;
|
||||
|
|
|
@ -42,6 +42,8 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
|||
|
||||
private Map<String, Object> params = null;
|
||||
|
||||
private String scoreMode;
|
||||
|
||||
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
|
||||
private ArrayList<String> scripts = new ArrayList<String>();
|
||||
private TFloatArrayList boosts = new TFloatArrayList();
|
||||
|
@ -64,6 +66,11 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
|||
return this;
|
||||
}
|
||||
|
||||
public CustomFiltersScoreQueryBuilder scoreMode(String scoreMode) {
|
||||
this.scoreMode = scoreMode;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the language of the script.
|
||||
*/
|
||||
|
@ -124,6 +131,10 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
|||
}
|
||||
builder.endArray();
|
||||
|
||||
if (scoreMode != null) {
|
||||
builder.field("score_mode", scoreMode);
|
||||
}
|
||||
|
||||
if (lang != null) {
|
||||
builder.field("lang", lang);
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
String scriptLang = null;
|
||||
Map<String, Object> vars = null;
|
||||
|
||||
FiltersFunctionScoreQuery.ScoreMode scoreMode = FiltersFunctionScoreQuery.ScoreMode.First;
|
||||
ArrayList<Filter> filters = new ArrayList<Filter>();
|
||||
ArrayList<String> scripts = new ArrayList<String>();
|
||||
TFloatArrayList boosts = new TFloatArrayList();
|
||||
|
@ -110,6 +111,19 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
scriptLang = parser.text();
|
||||
} else if ("boost".equals(currentFieldName)) {
|
||||
boost = parser.floatValue();
|
||||
} else if ("score_mode".equals(currentFieldName) || "scoreMode".equals(currentFieldName)) {
|
||||
String sScoreMode = parser.text();
|
||||
if ("avg".equals(sScoreMode)) {
|
||||
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Avg;
|
||||
} else if ("max".equals(sScoreMode)) {
|
||||
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Max;
|
||||
} else if ("total".equals(sScoreMode)) {
|
||||
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Total;
|
||||
} else if ("first".equals(sScoreMode)) {
|
||||
scoreMode = FiltersFunctionScoreQuery.ScoreMode.First;
|
||||
} else {
|
||||
throw new QueryParsingException(parseContext.index(), "illegal score_mode for nested query [" + sScoreMode + "]");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -136,7 +150,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
}
|
||||
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction);
|
||||
}
|
||||
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions);
|
||||
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode, filterFunctions);
|
||||
functionScoreQuery.setBoost(boost);
|
||||
return functionScoreQuery;
|
||||
}
|
||||
|
|
|
@ -148,10 +148,10 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
|||
@Test public void testCustomFiltersScore() throws Exception {
|
||||
client.admin().indices().prepareDelete().execute().actionGet();
|
||||
|
||||
client.prepareIndex("test", "type", "1").setSource("field", "value1").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "2").setSource("field", "value2").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "3").setSource("field", "value3").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "4").setSource("field", "value4").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "1").setSource("field", "value1", "color", "red").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "2").setSource("field", "value2", "color", "blue").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "3").setSource("field", "value3", "color", "red").execute().actionGet();
|
||||
client.prepareIndex("test", "type", "4").setSource("field", "value4", "color", "blue").execute().actionGet();
|
||||
|
||||
client.admin().indices().prepareRefresh().execute().actionGet();
|
||||
|
||||
|
@ -194,5 +194,50 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
|||
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
|
||||
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
|
||||
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
|
||||
|
||||
searchResponse = client.prepareSearch("test")
|
||||
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("total")
|
||||
.add(termFilter("field", "value4"), 2)
|
||||
.add(termFilter("field", "value1"), 3)
|
||||
.add(termFilter("color", "red"), 5))
|
||||
.setExplain(true)
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
|
||||
assertThat(searchResponse.hits().getAt(0).score(), equalTo(8.0f));
|
||||
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||
|
||||
searchResponse = client.prepareSearch("test")
|
||||
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("max")
|
||||
.add(termFilter("field", "value4"), 2)
|
||||
.add(termFilter("field", "value1"), 3)
|
||||
.add(termFilter("color", "red"), 5))
|
||||
.setExplain(true)
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
|
||||
assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f));
|
||||
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||
|
||||
searchResponse = client.prepareSearch("test")
|
||||
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("avg")
|
||||
.add(termFilter("field", "value4"), 2)
|
||||
.add(termFilter("field", "value1"), 3)
|
||||
.add(termFilter("color", "red"), 5))
|
||||
.setExplain(true)
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||
assertThat(searchResponse.hits().getAt(0).id(), equalTo("3"));
|
||||
assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f));
|
||||
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||
assertThat(searchResponse.hits().getAt(1).id(), equalTo("1"));
|
||||
assertThat(searchResponse.hits().getAt(1).score(), equalTo(4.0f));
|
||||
logger.info("--> Hit[1] {} Explanation {}", searchResponse.hits().getAt(1).id(), searchResponse.hits().getAt(1).explanation());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue