Query DSL: custom_filters_score - add score_mode to control filters matching scoring, closes #1205.
This commit is contained in:
parent
4a886dbae1
commit
d93bc02309
|
@ -34,6 +34,7 @@ import org.elasticsearch.common.lucene.docset.DocSet;
|
||||||
import org.elasticsearch.common.lucene.docset.DocSets;
|
import org.elasticsearch.common.lucene.docset.DocSets;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
@ -73,12 +74,16 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static enum ScoreMode {First, Avg, Max, Total}
|
||||||
|
|
||||||
Query subQuery;
|
Query subQuery;
|
||||||
final FilterFunction[] filterFunctions;
|
final FilterFunction[] filterFunctions;
|
||||||
|
final ScoreMode scoreMode;
|
||||||
DocSet[] docSets;
|
DocSet[] docSets;
|
||||||
|
|
||||||
public FiltersFunctionScoreQuery(Query subQuery, FilterFunction[] filterFunctions) {
|
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions) {
|
||||||
this.subQuery = subQuery;
|
this.subQuery = subQuery;
|
||||||
|
this.scoreMode = scoreMode;
|
||||||
this.filterFunctions = filterFunctions;
|
this.filterFunctions = filterFunctions;
|
||||||
this.docSets = new DocSet[filterFunctions.length];
|
this.docSets = new DocSet[filterFunctions.length];
|
||||||
}
|
}
|
||||||
|
@ -151,7 +156,7 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
filterFunction.function.setNextReader(reader);
|
filterFunction.function.setNextReader(reader);
|
||||||
docSets[i] = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
docSets[i] = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||||
}
|
}
|
||||||
return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filterFunctions, docSets);
|
return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, scoreMode, filterFunctions, docSets);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -161,16 +166,59 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
return subQueryExpl;
|
return subQueryExpl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (FilterFunction filterFunction : filterFunctions) {
|
if (scoreMode == ScoreMode.First) {
|
||||||
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
for (FilterFunction filterFunction : filterFunctions) {
|
||||||
if (docSet.get(doc)) {
|
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||||
filterFunction.function.setNextReader(reader);
|
if (docSet.get(doc)) {
|
||||||
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
|
filterFunction.function.setNextReader(reader);
|
||||||
float sc = getValue() * functionExplanation.getValue();
|
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
|
||||||
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
|
float sc = getValue() * functionExplanation.getValue();
|
||||||
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
|
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
|
||||||
res.addDetail(functionExplanation);
|
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
|
||||||
res.addDetail(new Explanation(getValue(), "queryBoost"));
|
res.addDetail(functionExplanation);
|
||||||
|
res.addDetail(new Explanation(getValue(), "queryBoost"));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int count = 0;
|
||||||
|
float total = 0;
|
||||||
|
float max = Float.NEGATIVE_INFINITY;
|
||||||
|
ArrayList<Explanation> filtersExplanations = new ArrayList<Explanation>();
|
||||||
|
for (FilterFunction filterFunction : filterFunctions) {
|
||||||
|
DocSet docSet = DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader));
|
||||||
|
if (docSet.get(doc)) {
|
||||||
|
filterFunction.function.setNextReader(reader);
|
||||||
|
Explanation functionExplanation = filterFunction.function.explain(doc, subQueryExpl);
|
||||||
|
float sc = functionExplanation.getValue();
|
||||||
|
count++;
|
||||||
|
total += sc;
|
||||||
|
max = Math.max(sc, max);
|
||||||
|
Explanation res = new ComplexExplanation(true, sc, "custom score, product of:");
|
||||||
|
res.addDetail(new Explanation(1.0f, "match filter: " + filterFunction.filter.toString()));
|
||||||
|
res.addDetail(functionExplanation);
|
||||||
|
res.addDetail(new Explanation(getValue(), "queryBoost"));
|
||||||
|
filtersExplanations.add(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (count > 0) {
|
||||||
|
float sc = 0;
|
||||||
|
switch (scoreMode) {
|
||||||
|
case Avg:
|
||||||
|
sc = total / count;
|
||||||
|
break;
|
||||||
|
case Max:
|
||||||
|
sc = max;
|
||||||
|
break;
|
||||||
|
case Total:
|
||||||
|
sc = total;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sc *= getValue();
|
||||||
|
Explanation res = new ComplexExplanation(true, sc, "custom score, score mode [" + scoreMode.toString().toLowerCase() + "]");
|
||||||
|
for (Explanation explanation : filtersExplanations) {
|
||||||
|
res.addDetail(explanation);
|
||||||
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -188,13 +236,15 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
private final float subQueryWeight;
|
private final float subQueryWeight;
|
||||||
private final Scorer scorer;
|
private final Scorer scorer;
|
||||||
private final FilterFunction[] filterFunctions;
|
private final FilterFunction[] filterFunctions;
|
||||||
|
private final ScoreMode scoreMode;
|
||||||
private final DocSet[] docSets;
|
private final DocSet[] docSets;
|
||||||
|
|
||||||
private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer,
|
private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer,
|
||||||
FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException {
|
ScoreMode scoreMode, FilterFunction[] filterFunctions, DocSet[] docSets) throws IOException {
|
||||||
super(similarity);
|
super(similarity);
|
||||||
this.subQueryWeight = w.getValue();
|
this.subQueryWeight = w.getValue();
|
||||||
this.scorer = scorer;
|
this.scorer = scorer;
|
||||||
|
this.scoreMode = scoreMode;
|
||||||
this.filterFunctions = filterFunctions;
|
this.filterFunctions = filterFunctions;
|
||||||
this.docSets = docSets;
|
this.docSets = docSets;
|
||||||
}
|
}
|
||||||
|
@ -218,9 +268,36 @@ public class FiltersFunctionScoreQuery extends Query {
|
||||||
public float score() throws IOException {
|
public float score() throws IOException {
|
||||||
int docId = scorer.docID();
|
int docId = scorer.docID();
|
||||||
float score = scorer.score();
|
float score = scorer.score();
|
||||||
for (int i = 0; i < filterFunctions.length; i++) {
|
if (scoreMode == ScoreMode.First) {
|
||||||
if (docSets[i].get(docId)) {
|
for (int i = 0; i < filterFunctions.length; i++) {
|
||||||
return subQueryWeight * filterFunctions[i].function.score(docId, score);
|
if (docSets[i].get(docId)) {
|
||||||
|
return subQueryWeight * filterFunctions[i].function.score(docId, score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (scoreMode == ScoreMode.Max) {
|
||||||
|
float maxScore = Float.NEGATIVE_INFINITY;
|
||||||
|
for (int i = 0; i < filterFunctions.length; i++) {
|
||||||
|
if (docSets[i].get(docId)) {
|
||||||
|
maxScore = Math.max(filterFunctions[i].function.score(docId, score), maxScore);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (maxScore != Float.NEGATIVE_INFINITY) {
|
||||||
|
score = maxScore;
|
||||||
|
}
|
||||||
|
} else { // Avg / Total
|
||||||
|
float totalScore = 0.0f;
|
||||||
|
int count = 0;
|
||||||
|
for (int i = 0; i < filterFunctions.length; i++) {
|
||||||
|
if (docSets[i].get(docId)) {
|
||||||
|
totalScore += filterFunctions[i].function.score(docId, score);
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (count != 0) {
|
||||||
|
score = totalScore;
|
||||||
|
if (scoreMode == ScoreMode.Avg) {
|
||||||
|
score /= count;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return subQueryWeight * score;
|
return subQueryWeight * score;
|
||||||
|
|
|
@ -42,6 +42,8 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
||||||
|
|
||||||
private Map<String, Object> params = null;
|
private Map<String, Object> params = null;
|
||||||
|
|
||||||
|
private String scoreMode;
|
||||||
|
|
||||||
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
|
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
|
||||||
private ArrayList<String> scripts = new ArrayList<String>();
|
private ArrayList<String> scripts = new ArrayList<String>();
|
||||||
private TFloatArrayList boosts = new TFloatArrayList();
|
private TFloatArrayList boosts = new TFloatArrayList();
|
||||||
|
@ -64,6 +66,11 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CustomFiltersScoreQueryBuilder scoreMode(String scoreMode) {
|
||||||
|
this.scoreMode = scoreMode;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the language of the script.
|
* Sets the language of the script.
|
||||||
*/
|
*/
|
||||||
|
@ -124,6 +131,10 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
||||||
}
|
}
|
||||||
builder.endArray();
|
builder.endArray();
|
||||||
|
|
||||||
|
if (scoreMode != null) {
|
||||||
|
builder.field("score_mode", scoreMode);
|
||||||
|
}
|
||||||
|
|
||||||
if (lang != null) {
|
if (lang != null) {
|
||||||
builder.field("lang", lang);
|
builder.field("lang", lang);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
||||||
String scriptLang = null;
|
String scriptLang = null;
|
||||||
Map<String, Object> vars = null;
|
Map<String, Object> vars = null;
|
||||||
|
|
||||||
|
FiltersFunctionScoreQuery.ScoreMode scoreMode = FiltersFunctionScoreQuery.ScoreMode.First;
|
||||||
ArrayList<Filter> filters = new ArrayList<Filter>();
|
ArrayList<Filter> filters = new ArrayList<Filter>();
|
||||||
ArrayList<String> scripts = new ArrayList<String>();
|
ArrayList<String> scripts = new ArrayList<String>();
|
||||||
TFloatArrayList boosts = new TFloatArrayList();
|
TFloatArrayList boosts = new TFloatArrayList();
|
||||||
|
@ -110,6 +111,19 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
||||||
scriptLang = parser.text();
|
scriptLang = parser.text();
|
||||||
} else if ("boost".equals(currentFieldName)) {
|
} else if ("boost".equals(currentFieldName)) {
|
||||||
boost = parser.floatValue();
|
boost = parser.floatValue();
|
||||||
|
} else if ("score_mode".equals(currentFieldName) || "scoreMode".equals(currentFieldName)) {
|
||||||
|
String sScoreMode = parser.text();
|
||||||
|
if ("avg".equals(sScoreMode)) {
|
||||||
|
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Avg;
|
||||||
|
} else if ("max".equals(sScoreMode)) {
|
||||||
|
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Max;
|
||||||
|
} else if ("total".equals(sScoreMode)) {
|
||||||
|
scoreMode = FiltersFunctionScoreQuery.ScoreMode.Total;
|
||||||
|
} else if ("first".equals(sScoreMode)) {
|
||||||
|
scoreMode = FiltersFunctionScoreQuery.ScoreMode.First;
|
||||||
|
} else {
|
||||||
|
throw new QueryParsingException(parseContext.index(), "illegal score_mode for nested query [" + sScoreMode + "]");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -136,7 +150,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
||||||
}
|
}
|
||||||
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction);
|
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction);
|
||||||
}
|
}
|
||||||
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions);
|
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode, filterFunctions);
|
||||||
functionScoreQuery.setBoost(boost);
|
functionScoreQuery.setBoost(boost);
|
||||||
return functionScoreQuery;
|
return functionScoreQuery;
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,10 +148,10 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
||||||
@Test public void testCustomFiltersScore() throws Exception {
|
@Test public void testCustomFiltersScore() throws Exception {
|
||||||
client.admin().indices().prepareDelete().execute().actionGet();
|
client.admin().indices().prepareDelete().execute().actionGet();
|
||||||
|
|
||||||
client.prepareIndex("test", "type", "1").setSource("field", "value1").execute().actionGet();
|
client.prepareIndex("test", "type", "1").setSource("field", "value1", "color", "red").execute().actionGet();
|
||||||
client.prepareIndex("test", "type", "2").setSource("field", "value2").execute().actionGet();
|
client.prepareIndex("test", "type", "2").setSource("field", "value2", "color", "blue").execute().actionGet();
|
||||||
client.prepareIndex("test", "type", "3").setSource("field", "value3").execute().actionGet();
|
client.prepareIndex("test", "type", "3").setSource("field", "value3", "color", "red").execute().actionGet();
|
||||||
client.prepareIndex("test", "type", "4").setSource("field", "value4").execute().actionGet();
|
client.prepareIndex("test", "type", "4").setSource("field", "value4", "color", "blue").execute().actionGet();
|
||||||
|
|
||||||
client.admin().indices().prepareRefresh().execute().actionGet();
|
client.admin().indices().prepareRefresh().execute().actionGet();
|
||||||
|
|
||||||
|
@ -194,5 +194,50 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
||||||
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
|
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
|
||||||
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
|
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
|
||||||
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
|
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
|
||||||
|
|
||||||
|
searchResponse = client.prepareSearch("test")
|
||||||
|
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("total")
|
||||||
|
.add(termFilter("field", "value4"), 2)
|
||||||
|
.add(termFilter("field", "value1"), 3)
|
||||||
|
.add(termFilter("color", "red"), 5))
|
||||||
|
.setExplain(true)
|
||||||
|
.execute().actionGet();
|
||||||
|
|
||||||
|
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||||
|
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||||
|
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
|
||||||
|
assertThat(searchResponse.hits().getAt(0).score(), equalTo(8.0f));
|
||||||
|
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||||
|
|
||||||
|
searchResponse = client.prepareSearch("test")
|
||||||
|
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("max")
|
||||||
|
.add(termFilter("field", "value4"), 2)
|
||||||
|
.add(termFilter("field", "value1"), 3)
|
||||||
|
.add(termFilter("color", "red"), 5))
|
||||||
|
.setExplain(true)
|
||||||
|
.execute().actionGet();
|
||||||
|
|
||||||
|
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||||
|
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||||
|
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
|
||||||
|
assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f));
|
||||||
|
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||||
|
|
||||||
|
searchResponse = client.prepareSearch("test")
|
||||||
|
.setQuery(customFiltersScoreQuery(matchAllQuery()).scoreMode("avg")
|
||||||
|
.add(termFilter("field", "value4"), 2)
|
||||||
|
.add(termFilter("field", "value1"), 3)
|
||||||
|
.add(termFilter("color", "red"), 5))
|
||||||
|
.setExplain(true)
|
||||||
|
.execute().actionGet();
|
||||||
|
|
||||||
|
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||||
|
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||||
|
assertThat(searchResponse.hits().getAt(0).id(), equalTo("3"));
|
||||||
|
assertThat(searchResponse.hits().getAt(0).score(), equalTo(5.0f));
|
||||||
|
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||||
|
assertThat(searchResponse.hits().getAt(1).id(), equalTo("1"));
|
||||||
|
assertThat(searchResponse.hits().getAt(1).score(), equalTo(4.0f));
|
||||||
|
logger.info("--> Hit[1] {} Explanation {}", searchResponse.hits().getAt(1).id(), searchResponse.hits().getAt(1).explanation());
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue