allow to pass parameters to custom score script

This commit is contained in:
kimchy 2010-06-14 04:12:57 +03:00
parent 1b32c1ccf4
commit 953779ccea
3 changed files with 53 additions and 4 deletions

View File

@ -19,9 +19,11 @@
package org.elasticsearch.index.query.xcontent;
import org.elasticsearch.util.collect.Maps;
import org.elasticsearch.util.xcontent.builder.XContentBuilder;
import java.io.IOException;
import java.util.Map;
/**
* A query that uses a script to compute the score.
@ -36,6 +38,8 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
private float boost = -1;
private Map<String, Object> params = null;
/**
* A query that simply applies the boost factor to another query (multiply it).
*
@ -53,6 +57,29 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
return this;
}
/**
* Additional parameters that can be provided to the script.
*/
public CustomScoreQueryBuilder params(Map<String, Object> params) {
if (params == null) {
this.params = params;
} else {
this.params.putAll(params);
}
return this;
}
/**
* Additional parameters that can be provided to the script.
*/
public CustomScoreQueryBuilder param(String key, Object value) {
if (params == null) {
params = Maps.newHashMap();
}
params.put(key, value);
return this;
}
/**
* Sets the boost for this query. Documents matching this query will (in addition to the normal
* weightings) have their score multiplied by the boost provided.
@ -67,6 +94,10 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
builder.field("query");
queryBuilder.toXContent(builder, params);
builder.field("script", script);
if (this.params != null) {
builder.field("params");
builder.map(this.params);
}
if (boost != -1) {
builder.field("boost", boost);
}

View File

@ -60,6 +60,7 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
Query query = null;
float boost = 1.0f;
String script = null;
Map<String, Object> vars = null;
String currentFieldName = null;
XContentParser.Token token;
@ -69,6 +70,8 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
} else if (token == XContentParser.Token.START_OBJECT) {
if ("query".equals(currentFieldName)) {
query = parseContext.parseInnerQuery();
} else if ("params".equals(currentFieldName)) {
vars = parser.map();
}
} else if (token.isValue()) {
if ("script".equals(currentFieldName)) {
@ -85,7 +88,7 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
throw new QueryParsingException(index, "[custom_score] requires 'script' field");
}
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(query,
new ScriptScoreFunction(new ScriptFieldsFunction(script, parseContext.scriptService(), parseContext.mapperService(), parseContext.indexCache().fieldData())));
new ScriptScoreFunction(new ScriptFieldsFunction(script, parseContext.scriptService(), parseContext.mapperService(), parseContext.indexCache().fieldData()), vars));
functionScoreQuery.setBoost(boost);
return functionScoreQuery;
}
@ -102,14 +105,17 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
private Map<String, Object> vars;
private ScriptScoreFunction(ScriptFieldsFunction scriptFieldsFunction) {
private ScriptScoreFunction(ScriptFieldsFunction scriptFieldsFunction, Map<String, Object> vars) {
this.scriptFieldsFunction = scriptFieldsFunction;
this.vars = vars;
}
@Override public void setNextReader(IndexReader reader) {
scriptFieldsFunction.setNextReader(reader);
vars = cachedVars.get().get();
vars.clear();
if (vars == null) {
vars = cachedVars.get().get();
vars.clear();
}
}
@Override public float score(int docId, float subQueryScore) {

View File

@ -142,5 +142,17 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
assertThat(response.hits().getAt(0).id(), equalTo("2"));
assertThat(response.hits().getAt(1).id(), equalTo("1"));
logger.info("running param1 * param2 * score");
response = client.search(searchRequest()
.searchType(SearchType.QUERY_THEN_FETCH)
.source(searchSource().explain(true).query(customScoreQuery(termQuery("test", "value")).script("param1 * param2 * score").param("param1", 2).param("param2", 2)))
).actionGet();
assertThat(response.hits().totalHits(), equalTo(2l));
logger.info("Hit[0] {} Explanation {}", response.hits().getAt(0).id(), response.hits().getAt(0).explanation());
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
assertThat(response.hits().getAt(0).id(), equalTo("1"));
assertThat(response.hits().getAt(1).id(), equalTo("2"));
}
}