allow to pass parameters to custom score script
This commit is contained in:
parent
1b32c1ccf4
commit
953779ccea
|
@ -19,9 +19,11 @@
|
|||
|
||||
package org.elasticsearch.index.query.xcontent;
|
||||
|
||||
import org.elasticsearch.util.collect.Maps;
|
||||
import org.elasticsearch.util.xcontent.builder.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A query that uses a script to compute the score.
|
||||
|
@ -36,6 +38,8 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
|
|||
|
||||
private float boost = -1;
|
||||
|
||||
private Map<String, Object> params = null;
|
||||
|
||||
/**
|
||||
* A query that simply applies the boost factor to another query (multiply it).
|
||||
*
|
||||
|
@ -53,6 +57,29 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Additional parameters that can be provided to the script.
|
||||
*/
|
||||
public CustomScoreQueryBuilder params(Map<String, Object> params) {
|
||||
if (params == null) {
|
||||
this.params = params;
|
||||
} else {
|
||||
this.params.putAll(params);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Additional parameters that can be provided to the script.
|
||||
*/
|
||||
public CustomScoreQueryBuilder param(String key, Object value) {
|
||||
if (params == null) {
|
||||
params = Maps.newHashMap();
|
||||
}
|
||||
params.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the boost for this query. Documents matching this query will (in addition to the normal
|
||||
* weightings) have their score multiplied by the boost provided.
|
||||
|
@ -67,6 +94,10 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
|
|||
builder.field("query");
|
||||
queryBuilder.toXContent(builder, params);
|
||||
builder.field("script", script);
|
||||
if (this.params != null) {
|
||||
builder.field("params");
|
||||
builder.map(this.params);
|
||||
}
|
||||
if (boost != -1) {
|
||||
builder.field("boost", boost);
|
||||
}
|
||||
|
|
|
@ -60,6 +60,7 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
|||
Query query = null;
|
||||
float boost = 1.0f;
|
||||
String script = null;
|
||||
Map<String, Object> vars = null;
|
||||
|
||||
String currentFieldName = null;
|
||||
XContentParser.Token token;
|
||||
|
@ -69,6 +70,8 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
|||
} else if (token == XContentParser.Token.START_OBJECT) {
|
||||
if ("query".equals(currentFieldName)) {
|
||||
query = parseContext.parseInnerQuery();
|
||||
} else if ("params".equals(currentFieldName)) {
|
||||
vars = parser.map();
|
||||
}
|
||||
} else if (token.isValue()) {
|
||||
if ("script".equals(currentFieldName)) {
|
||||
|
@ -85,7 +88,7 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
|||
throw new QueryParsingException(index, "[custom_score] requires 'script' field");
|
||||
}
|
||||
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(query,
|
||||
new ScriptScoreFunction(new ScriptFieldsFunction(script, parseContext.scriptService(), parseContext.mapperService(), parseContext.indexCache().fieldData())));
|
||||
new ScriptScoreFunction(new ScriptFieldsFunction(script, parseContext.scriptService(), parseContext.mapperService(), parseContext.indexCache().fieldData()), vars));
|
||||
functionScoreQuery.setBoost(boost);
|
||||
return functionScoreQuery;
|
||||
}
|
||||
|
@ -102,14 +105,17 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
|||
|
||||
private Map<String, Object> vars;
|
||||
|
||||
private ScriptScoreFunction(ScriptFieldsFunction scriptFieldsFunction) {
|
||||
private ScriptScoreFunction(ScriptFieldsFunction scriptFieldsFunction, Map<String, Object> vars) {
|
||||
this.scriptFieldsFunction = scriptFieldsFunction;
|
||||
this.vars = vars;
|
||||
}
|
||||
|
||||
@Override public void setNextReader(IndexReader reader) {
|
||||
scriptFieldsFunction.setNextReader(reader);
|
||||
vars = cachedVars.get().get();
|
||||
vars.clear();
|
||||
if (vars == null) {
|
||||
vars = cachedVars.get().get();
|
||||
vars.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@Override public float score(int docId, float subQueryScore) {
|
||||
|
|
|
@ -142,5 +142,17 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
|||
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
|
||||
assertThat(response.hits().getAt(0).id(), equalTo("2"));
|
||||
assertThat(response.hits().getAt(1).id(), equalTo("1"));
|
||||
|
||||
logger.info("running param1 * param2 * score");
|
||||
response = client.search(searchRequest()
|
||||
.searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(searchSource().explain(true).query(customScoreQuery(termQuery("test", "value")).script("param1 * param2 * score").param("param1", 2).param("param2", 2)))
|
||||
).actionGet();
|
||||
|
||||
assertThat(response.hits().totalHits(), equalTo(2l));
|
||||
logger.info("Hit[0] {} Explanation {}", response.hits().getAt(0).id(), response.hits().getAt(0).explanation());
|
||||
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
|
||||
assertThat(response.hits().getAt(0).id(), equalTo("1"));
|
||||
assertThat(response.hits().getAt(1).id(), equalTo("2"));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue