Query DSL: custom_filters_score allow to associate boost on filter instead of script, closes #1204.

This commit is contained in:
Shay Banon 2011-08-04 02:50:58 +03:00
parent 5845baa3e0
commit 4a886dbae1
3 changed files with 55 additions and 5 deletions

View File

@ -20,6 +20,7 @@
package org.elasticsearch.index.query;
import org.elasticsearch.common.collect.Maps;
import org.elasticsearch.common.trove.list.array.TFloatArrayList;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
@ -43,6 +44,7 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
private ArrayList<FilterBuilder> filters = new ArrayList<FilterBuilder>();
private ArrayList<String> scripts = new ArrayList<String>();
private TFloatArrayList boosts = new TFloatArrayList();
public CustomFiltersScoreQueryBuilder(QueryBuilder queryBuilder) {
this.queryBuilder = queryBuilder;
@ -51,6 +53,14 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
public CustomFiltersScoreQueryBuilder add(FilterBuilder filter, String script) {
this.filters.add(filter);
this.scripts.add(script);
this.boosts.add(-1);
return this;
}
public CustomFiltersScoreQueryBuilder add(FilterBuilder filter, float boost) {
this.filters.add(filter);
this.scripts.add(null);
this.boosts.add(boost);
return this;
}
@ -104,7 +114,12 @@ public class CustomFiltersScoreQueryBuilder extends BaseQueryBuilder {
builder.startObject();
builder.field("filter");
filters.get(i).toXContent(builder, params);
builder.field("script", scripts.get(i));
String script = scripts.get(i);
if (script != null) {
builder.field("script", script);
} else {
builder.field("boost", boosts.get(i));
}
builder.endObject();
}
builder.endArray();

View File

@ -24,7 +24,10 @@ import org.apache.lucene.search.Query;
import org.elasticsearch.ElasticSearchIllegalStateException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.lucene.search.function.BoostScoreFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.trove.list.array.TFloatArrayList;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.internal.SearchContext;
@ -57,6 +60,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
ArrayList<Filter> filters = new ArrayList<Filter>();
ArrayList<String> scripts = new ArrayList<String>();
TFloatArrayList boosts = new TFloatArrayList();
String currentFieldName = null;
XContentParser.Token token;
@ -74,6 +78,7 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
String script = null;
Filter filter = null;
float fboost = Float.NaN;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName();
@ -84,17 +89,20 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
} else if (token.isValue()) {
if ("script".equals(currentFieldName)) {
script = parser.text();
} else if ("boost".equals(currentFieldName)) {
fboost = parser.floatValue();
}
}
}
if (script == null) {
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'script' in filters array element");
if (script == null && fboost == -1) {
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'script' or 'boost' in filters array element");
}
if (filter == null) {
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] missing 'filter' in filters array element");
}
filters.add(filter);
scripts.add(script);
boosts.add(fboost);
}
}
} else if (token.isValue()) {
@ -118,8 +126,15 @@ public class CustomFiltersScoreQueryParser implements QueryParser {
}
FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[filters.size()];
for (int i = 0; i < filterFunctions.length; i++) {
SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, scripts.get(i), vars);
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), new CustomScoreQueryParser.ScriptScoreFunction(searchScript));
ScoreFunction scoreFunction;
String script = scripts.get(i);
if (script != null) {
SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars);
scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(searchScript);
} else {
scoreFunction = new BoostScoreFunction(boosts.get(i));
}
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(filters.get(i), scoreFunction);
}
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterFunctions);
functionScoreQuery.setBoost(boost);

View File

@ -174,5 +174,25 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
searchResponse = client.prepareSearch("test")
.setQuery(customFiltersScoreQuery(matchAllQuery())
.add(termFilter("field", "value4"), 2)
.add(termFilter("field", "value2"), 3))
.setExplain(true)
.execute().actionGet();
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
assertThat(searchResponse.hits().getAt(0).id(), equalTo("2"));
assertThat(searchResponse.hits().getAt(0).score(), equalTo(3.0f));
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
assertThat(searchResponse.hits().getAt(1).id(), equalTo("4"));
assertThat(searchResponse.hits().getAt(1).score(), equalTo(2.0f));
assertThat(searchResponse.hits().getAt(2).id(), anyOf(equalTo("1"), equalTo("3")));
assertThat(searchResponse.hits().getAt(2).score(), equalTo(1.0f));
assertThat(searchResponse.hits().getAt(3).id(), anyOf(equalTo("1"), equalTo("3")));
assertThat(searchResponse.hits().getAt(3).score(), equalTo(1.0f));
}
}