allow to pass parameters to custom score script
This commit is contained in:
parent
1b32c1ccf4
commit
953779ccea
|
@ -19,9 +19,11 @@
|
||||||
|
|
||||||
package org.elasticsearch.index.query.xcontent;
|
package org.elasticsearch.index.query.xcontent;
|
||||||
|
|
||||||
|
import org.elasticsearch.util.collect.Maps;
|
||||||
import org.elasticsearch.util.xcontent.builder.XContentBuilder;
|
import org.elasticsearch.util.xcontent.builder.XContentBuilder;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A query that uses a script to compute the score.
|
* A query that uses a script to compute the score.
|
||||||
|
@ -36,6 +38,8 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
|
||||||
|
|
||||||
private float boost = -1;
|
private float boost = -1;
|
||||||
|
|
||||||
|
private Map<String, Object> params = null;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A query that simply applies the boost factor to another query (multiply it).
|
* A query that simply applies the boost factor to another query (multiply it).
|
||||||
*
|
*
|
||||||
|
@ -53,6 +57,29 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Additional parameters that can be provided to the script.
|
||||||
|
*/
|
||||||
|
public CustomScoreQueryBuilder params(Map<String, Object> params) {
|
||||||
|
if (params == null) {
|
||||||
|
this.params = params;
|
||||||
|
} else {
|
||||||
|
this.params.putAll(params);
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Additional parameters that can be provided to the script.
|
||||||
|
*/
|
||||||
|
public CustomScoreQueryBuilder param(String key, Object value) {
|
||||||
|
if (params == null) {
|
||||||
|
params = Maps.newHashMap();
|
||||||
|
}
|
||||||
|
params.put(key, value);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the boost for this query. Documents matching this query will (in addition to the normal
|
* Sets the boost for this query. Documents matching this query will (in addition to the normal
|
||||||
* weightings) have their score multiplied by the boost provided.
|
* weightings) have their score multiplied by the boost provided.
|
||||||
|
@ -67,6 +94,10 @@ public class CustomScoreQueryBuilder extends BaseQueryBuilder {
|
||||||
builder.field("query");
|
builder.field("query");
|
||||||
queryBuilder.toXContent(builder, params);
|
queryBuilder.toXContent(builder, params);
|
||||||
builder.field("script", script);
|
builder.field("script", script);
|
||||||
|
if (this.params != null) {
|
||||||
|
builder.field("params");
|
||||||
|
builder.map(this.params);
|
||||||
|
}
|
||||||
if (boost != -1) {
|
if (boost != -1) {
|
||||||
builder.field("boost", boost);
|
builder.field("boost", boost);
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,6 +60,7 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
||||||
Query query = null;
|
Query query = null;
|
||||||
float boost = 1.0f;
|
float boost = 1.0f;
|
||||||
String script = null;
|
String script = null;
|
||||||
|
Map<String, Object> vars = null;
|
||||||
|
|
||||||
String currentFieldName = null;
|
String currentFieldName = null;
|
||||||
XContentParser.Token token;
|
XContentParser.Token token;
|
||||||
|
@ -69,6 +70,8 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
||||||
} else if (token == XContentParser.Token.START_OBJECT) {
|
} else if (token == XContentParser.Token.START_OBJECT) {
|
||||||
if ("query".equals(currentFieldName)) {
|
if ("query".equals(currentFieldName)) {
|
||||||
query = parseContext.parseInnerQuery();
|
query = parseContext.parseInnerQuery();
|
||||||
|
} else if ("params".equals(currentFieldName)) {
|
||||||
|
vars = parser.map();
|
||||||
}
|
}
|
||||||
} else if (token.isValue()) {
|
} else if (token.isValue()) {
|
||||||
if ("script".equals(currentFieldName)) {
|
if ("script".equals(currentFieldName)) {
|
||||||
|
@ -85,7 +88,7 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
||||||
throw new QueryParsingException(index, "[custom_score] requires 'script' field");
|
throw new QueryParsingException(index, "[custom_score] requires 'script' field");
|
||||||
}
|
}
|
||||||
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(query,
|
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(query,
|
||||||
new ScriptScoreFunction(new ScriptFieldsFunction(script, parseContext.scriptService(), parseContext.mapperService(), parseContext.indexCache().fieldData())));
|
new ScriptScoreFunction(new ScriptFieldsFunction(script, parseContext.scriptService(), parseContext.mapperService(), parseContext.indexCache().fieldData()), vars));
|
||||||
functionScoreQuery.setBoost(boost);
|
functionScoreQuery.setBoost(boost);
|
||||||
return functionScoreQuery;
|
return functionScoreQuery;
|
||||||
}
|
}
|
||||||
|
@ -102,14 +105,17 @@ public class CustomScoreQueryParser extends AbstractIndexComponent implements XC
|
||||||
|
|
||||||
private Map<String, Object> vars;
|
private Map<String, Object> vars;
|
||||||
|
|
||||||
private ScriptScoreFunction(ScriptFieldsFunction scriptFieldsFunction) {
|
private ScriptScoreFunction(ScriptFieldsFunction scriptFieldsFunction, Map<String, Object> vars) {
|
||||||
this.scriptFieldsFunction = scriptFieldsFunction;
|
this.scriptFieldsFunction = scriptFieldsFunction;
|
||||||
|
this.vars = vars;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override public void setNextReader(IndexReader reader) {
|
@Override public void setNextReader(IndexReader reader) {
|
||||||
scriptFieldsFunction.setNextReader(reader);
|
scriptFieldsFunction.setNextReader(reader);
|
||||||
vars = cachedVars.get().get();
|
if (vars == null) {
|
||||||
vars.clear();
|
vars = cachedVars.get().get();
|
||||||
|
vars.clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override public float score(int docId, float subQueryScore) {
|
@Override public float score(int docId, float subQueryScore) {
|
||||||
|
|
|
@ -142,5 +142,17 @@ public class CustomScoreSearchTests extends AbstractNodesTests {
|
||||||
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
|
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
|
||||||
assertThat(response.hits().getAt(0).id(), equalTo("2"));
|
assertThat(response.hits().getAt(0).id(), equalTo("2"));
|
||||||
assertThat(response.hits().getAt(1).id(), equalTo("1"));
|
assertThat(response.hits().getAt(1).id(), equalTo("1"));
|
||||||
|
|
||||||
|
logger.info("running param1 * param2 * score");
|
||||||
|
response = client.search(searchRequest()
|
||||||
|
.searchType(SearchType.QUERY_THEN_FETCH)
|
||||||
|
.source(searchSource().explain(true).query(customScoreQuery(termQuery("test", "value")).script("param1 * param2 * score").param("param1", 2).param("param2", 2)))
|
||||||
|
).actionGet();
|
||||||
|
|
||||||
|
assertThat(response.hits().totalHits(), equalTo(2l));
|
||||||
|
logger.info("Hit[0] {} Explanation {}", response.hits().getAt(0).id(), response.hits().getAt(0).explanation());
|
||||||
|
logger.info("Hit[1] {} Explanation {}", response.hits().getAt(1).id(), response.hits().getAt(1).explanation());
|
||||||
|
assertThat(response.hits().getAt(0).id(), equalTo("1"));
|
||||||
|
assertThat(response.hits().getAt(1).id(), equalTo("2"));
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue