First step towards supporting templating in rank eval requests. (#20374)

This adds support for templating in rank eval requests.

Relates to #20231

Problem: In it's current state the rank-eval request API forces the user to repeat complete queries for each test request. In most use cases the structure of the query to test will be stable with only parameters changing across requests, so this looks like lots of boilerplate json for something that could be expressed in a more concise way.

Uses templating/ ScriptServices to enable users to submit only one test request template and let them only specify template parameters on a per test request basis.
This commit is contained in:
Isabel Drost-Fromm 2016-11-01 11:36:22 +01:00 committed by GitHub
parent 51a2e3bf1e
commit 0b8a2e40cb
16 changed files with 454 additions and 109 deletions

View File

@ -186,83 +186,84 @@ public final class Script implements ToXContent, Writeable {
return builder; return builder;
} }
public static Script parse(XContentParser parser, ParseFieldMatcher parseFieldMatcher) throws IOException { public static Script parse(XContentParser parser, ParseFieldMatcher matcher) {
return parse(parser, parseFieldMatcher, null); return parse(parser, matcher, null);
} }
public static Script parse(XContentParser parser, QueryParseContext context) { public static Script parse(XContentParser parser, QueryParseContext context) {
return parse(parser, context.getParseFieldMatcher(), null);
}
public static Script parse(XContentParser parser, ParseFieldMatcher parseFieldMatcher, @Nullable String lang) {
try { try {
return parse(parser, context.getParseFieldMatcher(), context.getDefaultScriptLanguage()); XContentParser.Token token = parser.currentToken();
// If the parser hasn't yet been pushed to the first token, do it now
if (token == null) {
token = parser.nextToken();
}
if (token == XContentParser.Token.VALUE_STRING) {
return new Script(parser.text(), ScriptType.INLINE, lang, null);
}
if (token != XContentParser.Token.START_OBJECT) {
throw new ElasticsearchParseException("expected a string value or an object, but found [{}] instead", token);
}
String script = null;
ScriptType type = null;
Map<String, Object> params = null;
XContentType contentType = null;
String cfn = null;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
cfn = parser.currentName();
} else if (parseFieldMatcher.match(cfn, ScriptType.INLINE.getParseField())) {
type = ScriptType.INLINE;
if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
contentType = parser.contentType();
XContentBuilder builder = XContentFactory.contentBuilder(contentType);
script = builder.copyCurrentStructure(parser).bytes().utf8ToString();
} else {
script = parser.text();
}
} else if (parseFieldMatcher.match(cfn, ScriptType.FILE.getParseField())) {
type = ScriptType.FILE;
if (token == XContentParser.Token.VALUE_STRING) {
script = parser.text();
} else {
throw new ElasticsearchParseException("expected a string value for field [{}], but found [{}]", cfn, token);
}
} else if (parseFieldMatcher.match(cfn, ScriptType.STORED.getParseField())) {
type = ScriptType.STORED;
if (token == XContentParser.Token.VALUE_STRING) {
script = parser.text();
} else {
throw new ElasticsearchParseException("expected a string value for field [{}], but found [{}]", cfn, token);
}
} else if (parseFieldMatcher.match(cfn, ScriptField.LANG)) {
if (token == XContentParser.Token.VALUE_STRING) {
lang = parser.text();
} else {
throw new ElasticsearchParseException("expected a string value for field [{}], but found [{}]", cfn, token);
}
} else if (parseFieldMatcher.match(cfn, ScriptField.PARAMS)) {
if (token == XContentParser.Token.START_OBJECT) {
params = parser.map();
} else {
throw new ElasticsearchParseException("expected an object for field [{}], but found [{}]", cfn, token);
}
} else {
throw new ElasticsearchParseException("unexpected field [{}]", cfn);
}
}
if (script == null) {
throw new ElasticsearchParseException("expected one of [{}], [{}] or [{}] fields, but found none",
ScriptType.INLINE.getParseField() .getPreferredName(), ScriptType.FILE.getParseField().getPreferredName(),
ScriptType.STORED.getParseField() .getPreferredName());
}
return new Script(script, type, lang, params, contentType);
} catch (IOException e) { } catch (IOException e) {
throw new ParsingException(parser.getTokenLocation(), "Error parsing [" + ScriptField.SCRIPT.getPreferredName() + "] field", e); throw new ParsingException(parser.getTokenLocation(), "Error parsing [" + ScriptField.SCRIPT.getPreferredName() + "] field", e);
} }
}
public static Script parse(XContentParser parser, ParseFieldMatcher parseFieldMatcher, @Nullable String lang) throws IOException {
XContentParser.Token token = parser.currentToken();
// If the parser hasn't yet been pushed to the first token, do it now
if (token == null) {
token = parser.nextToken();
}
if (token == XContentParser.Token.VALUE_STRING) {
return new Script(parser.text(), ScriptType.INLINE, lang, null);
}
if (token != XContentParser.Token.START_OBJECT) {
throw new ElasticsearchParseException("expected a string value or an object, but found [{}] instead", token);
}
String script = null;
ScriptType type = null;
Map<String, Object> params = null;
XContentType contentType = null;
String cfn = null;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
cfn = parser.currentName();
} else if (parseFieldMatcher.match(cfn, ScriptType.INLINE.getParseField())) {
type = ScriptType.INLINE;
if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
contentType = parser.contentType();
XContentBuilder builder = XContentFactory.contentBuilder(contentType);
script = builder.copyCurrentStructure(parser).bytes().utf8ToString();
} else {
script = parser.text();
}
} else if (parseFieldMatcher.match(cfn, ScriptType.FILE.getParseField())) {
type = ScriptType.FILE;
if (token == XContentParser.Token.VALUE_STRING) {
script = parser.text();
} else {
throw new ElasticsearchParseException("expected a string value for field [{}], but found [{}]", cfn, token);
}
} else if (parseFieldMatcher.match(cfn, ScriptType.STORED.getParseField())) {
type = ScriptType.STORED;
if (token == XContentParser.Token.VALUE_STRING) {
script = parser.text();
} else {
throw new ElasticsearchParseException("expected a string value for field [{}], but found [{}]", cfn, token);
}
} else if (parseFieldMatcher.match(cfn, ScriptField.LANG)) {
if (token == XContentParser.Token.VALUE_STRING) {
lang = parser.text();
} else {
throw new ElasticsearchParseException("expected a string value for field [{}], but found [{}]", cfn, token);
}
} else if (parseFieldMatcher.match(cfn, ScriptField.PARAMS)) {
if (token == XContentParser.Token.START_OBJECT) {
params = parser.map();
} else {
throw new ElasticsearchParseException("expected an object for field [{}], but found [{}]", cfn, token);
}
} else {
throw new ElasticsearchParseException("unexpected field [{}]", cfn);
}
}
if (script == null) {
throw new ElasticsearchParseException("expected one of [{}], [{}] or [{}] fields, but found none",
ScriptType.INLINE.getParseField() .getPreferredName(), ScriptType.FILE.getParseField().getPreferredName(),
ScriptType.STORED.getParseField() .getPreferredName());
}
return new Script(script, type, lang, params, contentType);
} }
@Override @Override
@ -287,7 +288,7 @@ public final class Script implements ToXContent, Writeable {
@Override @Override
public String toString() { public String toString() {
return "[script: " + script + ", type: " + type.getParseField().getPreferredName() + ", lang: " return "[script: " + script + ", type: " + type.getParseField().getPreferredName() + ", lang: "
+ lang + ", params: " + params + "]"; + lang + ", params: " + params + ", contentType: " + contentType + "]";
} }
public interface ScriptField { public interface ScriptField {

View File

@ -51,7 +51,7 @@ public class PrecisionAtN implements RankedListQualityMetric {
/** ratings equal or above this value will be considered relevant. */ /** ratings equal or above this value will be considered relevant. */
private int relevantRatingThreshhold = 1; private int relevantRatingThreshhold = 1;
public static final String NAME = "precisionatn"; public static final String NAME = "precision_atn";
private static final ParseField SIZE_FIELD = new ParseField("size"); private static final ParseField SIZE_FIELD = new ParseField("size");
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");

View File

@ -23,6 +23,7 @@ import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchExtRegistry; import org.elasticsearch.search.SearchExtRegistry;
import org.elasticsearch.search.SearchRequestParsers; import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers; import org.elasticsearch.search.aggregations.AggregatorParsers;
@ -33,11 +34,14 @@ public class RankEvalContext implements ParseFieldMatcherSupplier {
private final SearchRequestParsers searchRequestParsers; private final SearchRequestParsers searchRequestParsers;
private final ParseFieldMatcher parseFieldMatcher; private final ParseFieldMatcher parseFieldMatcher;
private final QueryParseContext parseContext; private final QueryParseContext parseContext;
private final ScriptService scriptService;
public RankEvalContext(ParseFieldMatcher parseFieldMatcher, QueryParseContext parseContext, SearchRequestParsers searchRequestParsers) { public RankEvalContext(ParseFieldMatcher parseFieldMatcher, QueryParseContext parseContext, SearchRequestParsers searchRequestParsers,
ScriptService scriptService) {
this.parseFieldMatcher = parseFieldMatcher; this.parseFieldMatcher = parseFieldMatcher;
this.searchRequestParsers = searchRequestParsers; this.searchRequestParsers = searchRequestParsers;
this.parseContext = parseContext; this.parseContext = parseContext;
this.scriptService = scriptService;
} }
public Suggesters getSuggesters() { public Suggesters getSuggesters() {
@ -48,6 +52,14 @@ public class RankEvalContext implements ParseFieldMatcherSupplier {
return searchRequestParsers.aggParsers; return searchRequestParsers.aggParsers;
} }
public SearchRequestParsers getSearchRequestParsers() {
return searchRequestParsers;
}
public ScriptService getScriptService() {
return scriptService;
}
public SearchExtRegistry getSearchExtParsers() { public SearchExtRegistry getSearchExtParsers() {
return searchRequestParsers.searchExtParsers; return searchRequestParsers.searchExtParsers;
} }

View File

@ -20,19 +20,27 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes; import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Objects; import java.util.Objects;
import java.util.Map;
/** /**
* This class defines a ranking evaluation task including an id, a collection of queries to evaluate and the evaluation metric. * This class defines a ranking evaluation task including an id, a collection of queries to evaluate and the evaluation metric.
@ -42,11 +50,13 @@ import java.util.Objects;
* */ * */
public class RankEvalSpec extends ToXContentToBytes implements Writeable { public class RankEvalSpec extends ToXContentToBytes implements Writeable {
/** Collection of query specifications, that is e.g. search request templates to use for query translation. */ /** Collection of query specifications, that is e.g. search request templates to use for query translation. */
private Collection<RatedRequest> ratedRequests = new ArrayList<>(); private Collection<RatedRequest> ratedRequests = new ArrayList<>();
/** Definition of the quality metric, e.g. precision at N */ /** Definition of the quality metric, e.g. precision at N */
private RankedListQualityMetric metric; private RankedListQualityMetric metric;
/** optional: Template to base test requests on */
@Nullable
private Script template;
public RankEvalSpec() { public RankEvalSpec() {
// TODO think if no args ctor is okay // TODO think if no args ctor is okay
@ -64,6 +74,9 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
ratedRequests.add(new RatedRequest(in)); ratedRequests.add(new RatedRequest(in));
} }
metric = in.readNamedWriteable(RankedListQualityMetric.class); metric = in.readNamedWriteable(RankedListQualityMetric.class);
if (in.readBoolean()) {
template = new Script(in);
}
} }
@Override @Override
@ -73,22 +86,24 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
spec.writeTo(out); spec.writeTo(out);
} }
out.writeNamedWriteable(metric); out.writeNamedWriteable(metric);
if (template != null) {
out.writeBoolean(true);
template.writeTo(out);
} else {
out.writeBoolean(false);
}
} }
public void setEval(RankedListQualityMetric eval) { /** Set the metric to use for quality evaluation. */
this.metric = eval; public void setMetric(RankedListQualityMetric metric) {
this.metric = metric;
} }
/** Returns the precision at n configuration (containing level of n to consider).*/ /** Returns the metric to use for quality evaluation.*/
public RankedListQualityMetric getEvaluator() { public RankedListQualityMetric getMetric() {
return metric; return metric;
} }
/** Sets the precision at n configuration (containing level of n to consider).*/
public void setEvaluator(RankedListQualityMetric config) {
this.metric = config;
}
/** Returns a list of intent to query translation specifications to evaluate. */ /** Returns a list of intent to query translation specifications to evaluate. */
public Collection<RatedRequest> getSpecifications() { public Collection<RatedRequest> getSpecifications() {
return ratedRequests; return ratedRequests;
@ -98,19 +113,33 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
public void setSpecifications(Collection<RatedRequest> specifications) { public void setSpecifications(Collection<RatedRequest> specifications) {
this.ratedRequests = specifications; this.ratedRequests = specifications;
} }
/** Set the template to base test requests on. */
public void setTemplate(Script script) {
this.template = script;
}
/** Returns the template to base test requests on. */
public Script getTemplate() {
return this.template;
}
private static final ParseField TEMPLATE_FIELD = new ParseField("template");
private static final ParseField METRIC_FIELD = new ParseField("metric"); private static final ParseField METRIC_FIELD = new ParseField("metric");
private static final ParseField REQUESTS_FIELD = new ParseField("requests"); private static final ParseField REQUESTS_FIELD = new ParseField("requests");
private static final ObjectParser<RankEvalSpec, RankEvalContext> PARSER = new ObjectParser<>("rank_eval", RankEvalSpec::new); private static final ObjectParser<RankEvalSpec, RankEvalContext> PARSER = new ObjectParser<>("rank_eval", RankEvalSpec::new);
static { static {
PARSER.declareObject(RankEvalSpec::setEvaluator, (p, c) -> { PARSER.declareObject(RankEvalSpec::setMetric, (p, c) -> {
try { try {
return RankedListQualityMetric.fromXContent(p, c); return RankedListQualityMetric.fromXContent(p, c);
} catch (IOException ex) { } catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex); throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
} }
} , METRIC_FIELD); } , METRIC_FIELD);
PARSER.declareObject(RankEvalSpec::setTemplate, (p, c) -> {
return Script.parse(p, c.getParseFieldMatcher(), "mustache");
}, TEMPLATE_FIELD);
PARSER.declareObjectArray(RankEvalSpec::setSpecifications, (p, c) -> { PARSER.declareObjectArray(RankEvalSpec::setSpecifications, (p, c) -> {
try { try {
return RatedRequest.fromXContent(p, c); return RatedRequest.fromXContent(p, c);
@ -120,9 +149,43 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
} , REQUESTS_FIELD); } , REQUESTS_FIELD);
} }
public static RankEvalSpec parse(XContentParser parser, RankEvalContext context, boolean templated) throws IOException {
RankEvalSpec spec = PARSER.parse(parser, context);
if (templated) {
for (RatedRequest query_spec : spec.getSpecifications()) {
Map<String, String> params = query_spec.getParams();
Script scriptWithParams = new Script(spec.template.getScript(), spec.template.getType(), spec.template.getLang(), params);
String resolvedRequest =
((BytesReference)
(context.getScriptService().executable(scriptWithParams, ScriptContext.Standard.SEARCH, params)
.run()))
.utf8ToString();
try (XContentParser subParser = XContentFactory.xContent(resolvedRequest).createParser(resolvedRequest)) {
QueryParseContext parseContext =
new QueryParseContext(
context.getSearchRequestParsers().queryParsers,
subParser,
context.getParseFieldMatcher());
SearchSourceBuilder templateResult =
SearchSourceBuilder.fromXContent(
parseContext,
context.getAggs(),
context.getSuggesters(),
context.getSearchExtParsers());
query_spec.setTestRequest(templateResult);
}
}
}
return spec;
}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
if (this.template != null) {
builder.field(TEMPLATE_FIELD.getPreferredName(), this.template);
}
builder.startArray(REQUESTS_FIELD.getPreferredName()); builder.startArray(REQUESTS_FIELD.getPreferredName());
for (RatedRequest spec : this.ratedRequests) { for (RatedRequest spec : this.ratedRequests) {
spec.toXContent(builder, params); spec.toXContent(builder, params);
@ -133,10 +196,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
return builder; return builder;
} }
public static RankEvalSpec parse(XContentParser parser, RankEvalContext context) throws IOException {
return PARSER.parse(parser, context);
}
@Override @Override
public final boolean equals(Object obj) { public final boolean equals(Object obj) {
if (this == obj) { if (this == obj) {
@ -146,12 +205,14 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
return false; return false;
} }
RankEvalSpec other = (RankEvalSpec) obj; RankEvalSpec other = (RankEvalSpec) obj;
return Objects.equals(ratedRequests, other.ratedRequests) && return Objects.equals(ratedRequests, other.ratedRequests) &&
Objects.equals(metric, other.metric); Objects.equals(metric, other.metric) &&
Objects.equals(template, other.template);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(ratedRequests, metric); return Objects.hash(ratedRequests, metric, template);
} }
} }

View File

@ -32,17 +32,21 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Map;
import java.util.Map.Entry;
/** /**
* Defines a QA specification: All end user supplied query intents will be mapped to the search request specified in this search request * Defines a QA specification: All end user supplied query intents will be mapped to the search request specified in this search request
* template and executed against the targetIndex given. Any filters that should be applied in the target system can be specified as well. * template and executed against the targetIndex given. Any filters that should be applied in the target system can be specified as well.
* *
* The resulting document lists can then be compared against what was specified in the set of rated documents as part of a QAQuery. * The resulting document lists can then be compared against what was specified in the set of rated documents as part of a QAQuery.
* */ * */
@SuppressWarnings("unchecked")
public class RatedRequest extends ToXContentToBytes implements Writeable { public class RatedRequest extends ToXContentToBytes implements Writeable {
private String specId; private String specId;
private SearchSourceBuilder testRequest; private SearchSourceBuilder testRequest;
private List<String> indices = new ArrayList<>(); private List<String> indices = new ArrayList<>();
@ -50,6 +54,8 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
private List<String> summaryFields = new ArrayList<>(); private List<String> summaryFields = new ArrayList<>();
/** Collection of rated queries for this query QA specification.*/ /** Collection of rated queries for this query QA specification.*/
private List<RatedDocument> ratedDocs = new ArrayList<>(); private List<RatedDocument> ratedDocs = new ArrayList<>();
/** Map of parameters to use for filling a query template, can be used instead of providing testRequest. */
private Map<String, String> params = new HashMap<>();
public RatedRequest() { public RatedRequest() {
// ctor that doesn't require all args to be present immediatly is easier to use with ObjectParser // ctor that doesn't require all args to be present immediatly is easier to use with ObjectParser
@ -83,6 +89,7 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
for (int i = 0; i < intentSize; i++) { for (int i = 0; i < intentSize; i++) {
ratedDocs.add(new RatedDocument(in)); ratedDocs.add(new RatedDocument(in));
} }
this.params = (Map) in.readMap();
int summaryFieldsSize = in.readInt(); int summaryFieldsSize = in.readInt();
summaryFields = new ArrayList<>(summaryFieldsSize); summaryFields = new ArrayList<>(summaryFieldsSize);
for (int i = 0; i < summaryFieldsSize; i++) { for (int i = 0; i < summaryFieldsSize; i++) {
@ -106,6 +113,7 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
for (RatedDocument ratedDoc : ratedDocs) { for (RatedDocument ratedDoc : ratedDocs) {
ratedDoc.writeTo(out); ratedDoc.writeTo(out);
} }
out.writeMap((Map) params);
out.writeInt(summaryFields.size()); out.writeInt(summaryFields.size());
for (String fieldName : summaryFields) { for (String fieldName : summaryFields) {
out.writeString(fieldName); out.writeString(fieldName);
@ -155,6 +163,14 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
public void setRatedDocs(List<RatedDocument> ratedDocs) { public void setRatedDocs(List<RatedDocument> ratedDocs) {
this.ratedDocs = ratedDocs; this.ratedDocs = ratedDocs;
} }
public void setParams(Map<String, String> params) {
this.params = params;
}
public Map<String, String> getParams() {
return this.params;
}
public void setSummaryFields(List<String> fields) { public void setSummaryFields(List<String> fields) {
this.summaryFields = fields; this.summaryFields = fields;
@ -168,6 +184,7 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
private static final ParseField ID_FIELD = new ParseField("id"); private static final ParseField ID_FIELD = new ParseField("id");
private static final ParseField REQUEST_FIELD = new ParseField("request"); private static final ParseField REQUEST_FIELD = new ParseField("request");
private static final ParseField RATINGS_FIELD = new ParseField("ratings"); private static final ParseField RATINGS_FIELD = new ParseField("ratings");
private static final ParseField PARAMS_FIELD = new ParseField("params");
private static final ParseField FIELDS_FIELD = new ParseField("summary_fields"); private static final ParseField FIELDS_FIELD = new ParseField("summary_fields");
private static final ObjectParser<RatedRequest, RankEvalContext> PARSER = new ObjectParser<>("requests", RatedRequest::new); private static final ObjectParser<RatedRequest, RankEvalContext> PARSER = new ObjectParser<>("requests", RatedRequest::new);
@ -187,7 +204,14 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
} catch (IOException ex) { } catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing ratings", ex); throw new ParsingException(p.getTokenLocation(), "error parsing ratings", ex);
} }
} , RATINGS_FIELD); }, RATINGS_FIELD);
PARSER.declareObject(RatedRequest::setParams, (p, c) -> {
try {
return (Map) p.map();
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing ratings", ex);
}
}, PARAMS_FIELD);
} }
/** /**
@ -219,7 +243,13 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ID_FIELD.getPreferredName(), this.specId); builder.field(ID_FIELD.getPreferredName(), this.specId);
builder.field(REQUEST_FIELD.getPreferredName(), this.testRequest); if (testRequest != null)
builder.field(REQUEST_FIELD.getPreferredName(), this.testRequest);
builder.startObject(PARAMS_FIELD.getPreferredName());
for (Entry<String, String> entry : this.params.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
builder.startArray(RATINGS_FIELD.getPreferredName()); builder.startArray(RATINGS_FIELD.getPreferredName());
for (RatedDocument doc : this.ratedDocs) { for (RatedDocument doc : this.ratedDocs) {
doc.toXContent(builder, params); doc.toXContent(builder, params);
@ -250,11 +280,12 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
Objects.equals(indices, other.indices) && Objects.equals(indices, other.indices) &&
Objects.equals(types, other.types) && Objects.equals(types, other.types) &&
Objects.equals(summaryFields, summaryFields) && Objects.equals(summaryFields, summaryFields) &&
Objects.equals(ratedDocs, other.ratedDocs); Objects.equals(ratedDocs, other.ratedDocs) &&
Objects.equals(params, other.params);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(specId, testRequest, indices, types, summaryFields, ratedDocs); return Objects.hash(specId, testRequest, indices, types, summaryFields, ratedDocs, params);
} }
} }

View File

@ -32,6 +32,7 @@ import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestActions; import org.elasticsearch.rest.action.RestActions;
import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchRequestParsers; import org.elasticsearch.search.SearchRequestParsers;
import java.io.IOException; import java.io.IOException;
@ -156,19 +157,31 @@ import static org.elasticsearch.rest.RestRequest.Method.POST;
* */ * */
public class RestRankEvalAction extends BaseRestHandler { public class RestRankEvalAction extends BaseRestHandler {
private SearchRequestParsers searchRequestParsers; private SearchRequestParsers searchRequestParsers;
private ScriptService scriptService;
@Inject @Inject
public RestRankEvalAction(Settings settings, RestController controller, SearchRequestParsers searchRequestParsers) { public RestRankEvalAction(
Settings settings,
RestController controller,
SearchRequestParsers searchRequestParsers,
ScriptService scriptService) {
super(settings); super(settings);
this.searchRequestParsers = searchRequestParsers; this.searchRequestParsers = searchRequestParsers;
this.scriptService = scriptService;
controller.registerHandler(GET, "/_rank_eval", this); controller.registerHandler(GET, "/_rank_eval", this);
controller.registerHandler(POST, "/_rank_eval", this); controller.registerHandler(POST, "/_rank_eval", this);
controller.registerHandler(GET, "/{index}/_rank_eval", this); controller.registerHandler(GET, "/{index}/_rank_eval", this);
controller.registerHandler(POST, "/{index}/_rank_eval", this); controller.registerHandler(POST, "/{index}/_rank_eval", this);
controller.registerHandler(GET, "/{index}/{type}/_rank_eval", this); controller.registerHandler(GET, "/{index}/{type}/_rank_eval", this);
controller.registerHandler(POST, "/{index}/{type}/_rank_eval", this); controller.registerHandler(POST, "/{index}/{type}/_rank_eval", this);
controller.registerHandler(GET, "/_rank_eval/template", this);
controller.registerHandler(POST, "/_rank_eval/template", this);
controller.registerHandler(GET, "/{index}/_rank_eval/template", this);
controller.registerHandler(POST, "/{index}/_rank_eval/template", this);
controller.registerHandler(GET, "/{index}/{type}/_rank_eval/template", this);
controller.registerHandler(POST, "/{index}/{type}/_rank_eval/template", this);
} }
@Override @Override
@ -180,7 +193,7 @@ public class RestRankEvalAction extends BaseRestHandler {
if (restContent != null) { if (restContent != null) {
parseRankEvalRequest(rankEvalRequest, request, parseRankEvalRequest(rankEvalRequest, request,
// TODO can we get rid of aggregators parsers and suggesters? // TODO can we get rid of aggregators parsers and suggesters?
new RankEvalContext(parseFieldMatcher, parseContext, searchRequestParsers)); new RankEvalContext(parseFieldMatcher, parseContext, searchRequestParsers, scriptService));
} }
} }
return channel -> client.executeLocally(RankEvalAction.INSTANCE, rankEvalRequest, return channel -> client.executeLocally(RankEvalAction.INSTANCE, rankEvalRequest,
@ -191,7 +204,9 @@ public class RestRankEvalAction extends BaseRestHandler {
throws IOException { throws IOException {
List<String> indices = Arrays.asList(Strings.splitStringByCommaToArray(request.param("index"))); List<String> indices = Arrays.asList(Strings.splitStringByCommaToArray(request.param("index")));
List<String> types = Arrays.asList(Strings.splitStringByCommaToArray(request.param("type"))); List<String> types = Arrays.asList(Strings.splitStringByCommaToArray(request.param("type")));
RankEvalSpec spec = RankEvalSpec.parse(context.parser(), context); RankEvalSpec spec = null;
boolean containsTemplate = request.path().contains("template");
spec = RankEvalSpec.parse(context.parser(), context, containsTemplate);
for (RatedRequest specification : spec.getSpecifications()) { for (RatedRequest specification : spec.getSpecifications()) {
specification.setIndices(indices); specification.setIndices(indices);
specification.setTypes(types); specification.setTypes(types);

View File

@ -111,14 +111,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
@Override @Override
public void onResponse(SearchResponse searchResponse) { public void onResponse(SearchResponse searchResponse) {
SearchHit[] hits = searchResponse.getHits().getHits(); SearchHit[] hits = searchResponse.getHits().getHits();
EvalQueryQuality queryQuality = task.getEvaluator().evaluate(specification.getSpecId(), hits, EvalQueryQuality queryQuality = task.getMetric().evaluate(specification.getSpecId(), hits,
specification.getRatedDocs()); specification.getRatedDocs());
requestDetails.put(specification.getSpecId(), queryQuality); requestDetails.put(specification.getSpecId(), queryQuality);
if (responseCounter.decrementAndGet() < 1) { if (responseCounter.decrementAndGet() < 1) {
// TODO add other statistics like micro/macro avg? // TODO add other statistics like micro/macro avg?
listener.onResponse( listener.onResponse(
new RankEvalResponse(task.getEvaluator().combine(requestDetails.values()), requestDetails)); new RankEvalResponse(task.getMetric().combine(requestDetails.values()), requestDetails));
} }
} }

View File

@ -22,9 +22,17 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry; import org.elasticsearch.common.xcontent.ParseFieldRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.indices.query.IndicesQueriesRegistry; import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.SearchRequestParsers; import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers; import org.elasticsearch.search.aggregations.AggregatorParsers;
@ -35,7 +43,9 @@ import org.junit.BeforeClass;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
@ -89,13 +99,45 @@ public class RankEvalSpecTests extends ESTestCase {
RankEvalSpec testItem = new RankEvalSpec(specs, metric); RankEvalSpec testItem = new RankEvalSpec(specs, metric);
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); XContentType contentType = ESTestCase.randomFrom(XContentType.values());
XContent xContent = contentType.xContent();
if (randomBoolean()) {
final Map<String, Object> params = randomBoolean() ? null : Collections.singletonMap("key", "value");
ScriptService.ScriptType scriptType = randomFrom(ScriptService.ScriptType.values());
String script;
if (scriptType == ScriptService.ScriptType.INLINE) {
try (XContentBuilder builder = XContentBuilder.builder(xContent)) {
builder.startObject();
builder.field("field", randomAsciiOfLengthBetween(1, 5));
builder.endObject();
script = builder.string();
}
} else {
script = randomAsciiOfLengthBetween(1, 5);
}
testItem.setTemplate(new Script(
script,
scriptType,
randomFrom("_lang1", "_lang2", null),
params,
scriptType == ScriptService.ScriptType.INLINE ? xContent.type() : null));
}
XContentBuilder builder = XContentFactory.contentBuilder(contentType);
if (ESTestCase.randomBoolean()) {
builder.prettyPrint();
}
testItem.toXContent(builder, ToXContent.EMPTY_PARAMS);
XContentBuilder shuffled = ESTestCase.shuffleXContent(builder);
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes());
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT); QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext, RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers); searchRequestParsers, null);
RankEvalSpec parsedItem = RankEvalSpec.parse(itemParser, rankContext); RankEvalSpec parsedItem = RankEvalSpec.parse(itemParser, rankContext, false);
// IRL these come from URL parameters - see RestRankEvalAction // IRL these come from URL parameters - see RestRankEvalAction
parsedItem.getSpecifications().stream().forEach(e -> {e.setIndices(indices); e.setTypes(types);}); parsedItem.getSpecifications().stream().forEach(e -> {e.setIndices(indices); e.setTypes(types);});
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);

View File

@ -38,7 +38,9 @@ import org.junit.BeforeClass;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
@ -81,6 +83,16 @@ public class RatedRequestsTests extends ESTestCase {
RatedRequest ratedRequest = new RatedRequest(specId, testRequest, indices, types, ratedDocs); RatedRequest ratedRequest = new RatedRequest(specId, testRequest, indices, types, ratedDocs);
if (randomBoolean()) {
Map<String, String> params = new HashMap<String, String>();
int randomSize = randomIntBetween(1, 10);
for (int i = 0; i < randomSize; i++) {
params.put(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10));
}
ratedRequest.setParams(params);
}
List<String> summaryFields = new ArrayList<>(); List<String> summaryFields = new ArrayList<>();
int numSummaryFields = randomIntBetween(0, 5); int numSummaryFields = randomIntBetween(0, 5);
for (int i = 0; i < numSummaryFields; i++) { for (int i = 0; i < numSummaryFields; i++) {
@ -109,7 +121,7 @@ public class RatedRequestsTests extends ESTestCase {
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT); QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext, RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers); searchRequestParsers, null);
RatedRequest parsedItem = RatedRequest.fromXContent(itemParser, rankContext); RatedRequest parsedItem = RatedRequest.fromXContent(itemParser, rankContext);
parsedItem.setIndices(indices); // IRL these come from URL parameters - see RestRankEvalAction parsedItem.setIndices(indices); // IRL these come from URL parameters - see RestRankEvalAction
@ -143,7 +155,7 @@ public class RatedRequestsTests extends ESTestCase {
XContentParser parser = XContentFactory.xContent(querySpecString).createParser(querySpecString); XContentParser parser = XContentFactory.xContent(querySpecString).createParser(querySpecString);
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, parser, ParseFieldMatcher.STRICT); QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, parser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext, RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers); searchRequestParsers, null);
RatedRequest specification = RatedRequest.fromXContent(parser, rankContext); RatedRequest specification = RatedRequest.fromXContent(parser, rankContext);
assertEquals("my_qa_query", specification.getSpecId()); assertEquals("my_qa_query", specification.getSpecId());
assertNotNull(specification.getTestRequest()); assertNotNull(specification.getTestRequest());

View File

@ -56,7 +56,7 @@
"ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}] "ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}]
} }
], ],
"metric" : { "precisionatn": { "size": 10}} "metric" : { "precision_atn": { "size": 10}}
} }
- match: { rank_eval.quality_level: 1} - match: { rank_eval.quality_level: 1}

View File

@ -0,0 +1,27 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
apply plugin: 'elasticsearch.rest-test'
/*
dependencies {
testCompile project(path: ':modules:rank-eval', configuration: 'runtime')
testCompile project(path: ':modules:lang-mustache', configuration: 'runtime')
}
*/

View File

@ -0,0 +1,42 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.smoketest;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
import java.io.IOException;
public class SmokeTestRankEvalWithMustacheYAMLTestSuiteIT extends ESClientYamlSuiteTestCase {
public SmokeTestRankEvalWithMustacheYAMLTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}
@ParametersFactory
public static Iterable<Object[]> parameters() throws IOException, ClientYamlTestParseException {
return ESClientYamlSuiteTestCase.createParameters(0, 1);
}
}

View File

@ -0,0 +1,67 @@
---
"Template request":
- do:
indices.create:
index: foo
body:
settings:
index:
number_of_shards: 1
- do:
index:
index: foo
type: bar
id: doc1
body: { "text": "berlin" }
- do:
index:
index: foo
type: bar
id: doc2
body: { "text": "amsterdam" }
- do:
index:
index: foo
type: bar
id: doc3
body: { "text": "amsterdam" }
- do:
index:
index: foo
type: bar
id: doc4
body: { "text": "something about amsterdam and berlin" }
- do:
indices.refresh: {}
- do:
rank_eval_template:
body: {
"template": {
"inline": "{\"query\": { \"match\" : {\"text\" : \"{{query_string}}\" }}}"
},
"requests" : [
{
"id": "amsterdam_query",
"params": { "query_string": "amsterdam" },
"ratings": [
{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 0},
{"_index": "foo", "_type": "bar", "_id": "doc2", "rating": 1},
{"_index": "foo", "_type": "bar", "_id": "doc3", "rating": 1}]
},
{
"id" : "berlin_query",
"params": { "query_string": "berlin" },
"ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}]
}
],
"metric" : { "precision_atn": { "size": 10}}
}
- match: {rank_eval.quality_level: 1}
- match: {rank_eval.details.berlin_query.unknown_docs.0._id: "doc4"}
- match: {rank_eval.details.amsterdam_query.unknown_docs.0._id: "doc4"}

View File

@ -4,9 +4,18 @@
"methods": ["POST"], "methods": ["POST"],
"url": { "url": {
"path": "/_rank_eval", "path": "/_rank_eval",
"paths": ["/_rank_eval"], "paths": ["/_rank_eval", "/{index}/_rank_eval", "/{index}/{type}/_rank_eval"],
"parts": {}, "parts": {
"params": {} "index": {
"type": "list",
"description" : "A comma-separated list of index names to search; use `_all` or empty string to perform the operation on all indices"
},
"type": {
"type" : "list",
"description" : "A comma-separated list of document types to search; leave empty to perform the operation on all types"
}
},
"params": {}
}, },
"body": { "body": {
"description": "The search definition using the Query DSL and the prototype for the eval request.", "description": "The search definition using the Query DSL and the prototype for the eval request.",

View File

@ -0,0 +1,25 @@
{
"rank_eval_template": {
"documentation": "https://www.elastic.co/guide/en/elasticsearch/reference/master/docs-rank-eval.html",
"methods": ["POST"],
"url": {
"path": "/_rank_eval/template",
"paths": ["/_rank_eval/template", "/{index}/_rank_eval/template", "/{index}/{type}/_rank_eval/template"],
"parts": {
"index": {
"type": "list",
"description" : "A comma-separated list of index names to search; use `_all` or empty string to perform the operation on all indices"
},
"type": {
"type" : "list",
"description" : "A comma-separated list of document types to search; leave empty to perform the operation on all types"
}
},
"params": {}
},
"body": {
"description": "The search definition using the Query DSL and the prototype for the eval request.",
"required": true
}
}
}

View File

@ -61,6 +61,7 @@ List projects = [
'qa:smoke-test-ingest-with-all-dependencies', 'qa:smoke-test-ingest-with-all-dependencies',
'qa:smoke-test-ingest-disabled', 'qa:smoke-test-ingest-disabled',
'qa:smoke-test-multinode', 'qa:smoke-test-multinode',
'qa:smoke-test-rank-eval-with-mustache',
'qa:smoke-test-plugins', 'qa:smoke-test-plugins',
'qa:smoke-test-reindex-with-painless', 'qa:smoke-test-reindex-with-painless',
'qa:smoke-test-http', 'qa:smoke-test-http',