Query DSL: custom_filters_score allow to associate boost on filter instead of script, closes #1204.
This commit is contained in:
parent
5845baa3e0
commit
4a886dbae1
|
@ -20,6 +20,7 @@
|
|||
package org.elasticsearch.index.query;
|
||||
|
||||
import org.elasticsearch.common.collect.Maps;
|
||||
import org.elasticsearch.common.trove.list.array.TFloatArrayList;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -43,6 +44,7 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
|||
|
||||
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
|
||||
private ArrayList<String> scripts = new ArrayList<String>();
|
||||
private TFloatArrayList boosts = new TFloatArrayList();
|
||||
|
||||
public CustomFiltersScoreQueryBuilder(QueryBuilder queryBuilder) {
|
||||
this.queryBuilder = queryBuilder;
|
||||
|
@ -51,6 +53,14 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
|||
public CustomFiltersScoreQueryBuilder add(FilterBuilder filter, String script) {
|
||||
this.filters.add(filter);
|
||||
this.scripts.add(script);
|
||||
this.boosts.add(-1);
|
||||
return this;
|
||||
}
|
||||
|
||||
public CustomFiltersScoreQueryBuilder add(FilterBuilder filter, float boost) {
|
||||
this.filters.add(filter);
|
||||
this.scripts.add(null);
|
||||
this.boosts.add(boost);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -104,7 +114,12 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
|
|||
builder.startObject();
|
||||
builder.field("filter");
|
||||
filters.get(i).toXContent(builder, params);
|
||||
builder.field("script", scripts.get(i));
|
||||
String script = scripts.get(i);
|
||||
if (script != null) {
|
||||
builder.field("script", script);
|
||||
} else {
|
||||
builder.field("boost", boosts.get(i));
|
||||
}
|
||||
builder.endObject();
|
||||
}
|
||||
builder.endArray();
|
||||
|
|
|
@ -24,7 +24,10 @@ import org.apache.lucene.search.Query;
|
|||
import org.elasticsearch.ElasticSearchIllegalStateException;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.lucene.search.function.BoostScoreFunction;
|
||||
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
|
||||
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
|
||||
import org.elasticsearch.common.trove.list.array.TFloatArrayList;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.script.SearchScript;
|
||||
import org.elasticsearch.search.internal.SearchContext;
|
||||
|
@ -57,6 +60,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
|
||||
ArrayList<Filter> filters = new ArrayList<Filter>();
|
||||
ArrayList<String> scripts = new ArrayList<String>();
|
||||
TFloatArrayList boosts = new TFloatArrayList();
|
||||
|
||||
String currentFieldName = null;
|
||||
XContentParser.Token token;
|
||||
|
@ -74,6 +78,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
|
||||
String script = null;
|
||||
Filter filter = null;
|
||||
float fboost = Float.NaN;
|
||||
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
|
||||
if (token == XContentParser.Token.FIELD_NAME) {
|
||||
currentFieldName = parser.currentName();
|
||||
|
@ -84,17 +89,20 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
} else if (token.isValue()) {
|
||||
if ("script".equals(currentFieldName)) {
|
||||
script = parser.text();
|
||||
} else if ("boost".equals(currentFieldName)) {
|
||||
fboost = parser.floatValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (script == null) {
|
||||
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'script' in filters array element");
|
||||
if (script == null && fboost == -1) {
|
||||
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'script' or 'boost' in filters array element");
|
||||
}
|
||||
if (filter == null) {
|
||||
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'filter' in filters array element");
|
||||
}
|
||||
filters.add(filter);
|
||||
scripts.add(script);
|
||||
boosts.add(fboost);
|
||||
}
|
||||
}
|
||||
} else if (token.isValue()) {
|
||||
|
@ -118,8 +126,15 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
|
|||
}
|
||||
FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[filters.size()];
|
||||
for (int i = 0; i < filterFunctions.length; i++) {
|
||||
SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, scripts.get(i), vars);
|
||||
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), new CustomScoreQueryParser.ScriptScoreFunction(searchScript));
|
||||
ScoreFunction scoreFunction;
|
||||
String script = scripts.get(i);
|
||||
if (script != null) {
|
||||
SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars);
|
||||
scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(searchScript);
|
||||
} else {
|
||||
scoreFunction = new BoostScoreFunction(boosts.get(i));
|
||||
}
|
||||
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction);
|
||||
}
|
||||
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions);
|
||||
functionScoreQuery.setBoost(boost);
|
||||
|
|
|
@ -174,5 +174,25 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
|||
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
|
||||
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
|
||||
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
|
||||
|
||||
searchResponse = client.prepareSearch("test")
|
||||
.setQuery(customFiltersScoreQuery(matchAllQuery())
|
||||
.add(termFilter("field", "value4"), 2)
|
||||
.add(termFilter("field", "value2"), 3))
|
||||
.setExplain(true)
|
||||
.execute().actionGet();
|
||||
|
||||
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
|
||||
|
||||
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
|
||||
assertThat(searchResponse.hits().getAt(0).id(), equalTo("2"));
|
||||
assertThat(searchResponse.hits().getAt(0).score(), equalTo(3.0f));
|
||||
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
|
||||
assertThat(searchResponse.hits().getAt(1).id(), equalTo("4"));
|
||||
assertThat(searchResponse.hits().getAt(1).score(), equalTo(2.0f));
|
||||
assertThat(searchResponse.hits().getAt(2).id(), anyOf(equalTo("1"), equalTo("3")));
|
||||
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
|
||||
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
|
||||
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue