Updating rank-eval module after major changes on master

This commit is contained in:
Christoph Büscher 2017-02-03 20:53:30 +01:00
parent 4cb8d9d08c
commit dde2a09ba5
23 changed files with 236 additions and 371 deletions

View File

@ -20,7 +20,6 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
@ -141,7 +140,7 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ObjectParser<DiscountedCumulativeGain, ParseFieldMatcherSupplier> PARSER = private static final ObjectParser<DiscountedCumulativeGain, Void> PARSER =
new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGain()); new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGain());
static { static {
@ -149,8 +148,8 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD); PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD);
} }
public static DiscountedCumulativeGain fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
return PARSER.apply(parser, matcher); return PARSER.apply(parser, null);
} }
@Override @Override

View File

@ -20,7 +20,6 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
@ -49,7 +48,7 @@ public class Precision implements RankedListQualityMetric {
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled"); private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
private static final ObjectParser<Precision, ParseFieldMatcherSupplier> PARSER = new ObjectParser<>(NAME, Precision::new); private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME, Precision::new);
/** /**
* This setting controls how unlabeled documents in the search hits are * This setting controls how unlabeled documents in the search hits are
@ -118,8 +117,8 @@ public class Precision implements RankedListQualityMetric {
return ignoreUnlabeled; return ignoreUnlabeled;
} }
public static Precision fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { public static Precision fromXContent(XContentParser parser) {
return PARSER.apply(parser, matcher); return PARSER.apply(parser, null);
} }
/** Compute precisionAtN based on provided relevant document IDs. /** Compute precisionAtN based on provided relevant document IDs.

View File

@ -1,80 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchExtRegistry;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers;
import org.elasticsearch.search.suggest.Suggesters;
public class RankEvalContext implements ParseFieldMatcherSupplier {
private final SearchRequestParsers searchRequestParsers;
private final ParseFieldMatcher parseFieldMatcher;
private final QueryParseContext parseContext;
private final ScriptService scriptService;
public RankEvalContext(ParseFieldMatcher parseFieldMatcher, QueryParseContext parseContext, SearchRequestParsers searchRequestParsers,
ScriptService scriptService) {
this.parseFieldMatcher = parseFieldMatcher;
this.searchRequestParsers = searchRequestParsers;
this.parseContext = parseContext;
this.scriptService = scriptService;
}
public Suggesters getSuggesters() {
return searchRequestParsers.suggesters;
}
public AggregatorParsers getAggs() {
return searchRequestParsers.aggParsers;
}
public SearchRequestParsers getSearchRequestParsers() {
return searchRequestParsers;
}
public ScriptService getScriptService() {
return scriptService;
}
public SearchExtRegistry getSearchExtParsers() {
return searchRequestParsers.searchExtParsers;
}
@Override
public ParseFieldMatcher getParseFieldMatcher() {
return this.parseFieldMatcher;
}
public XContentParser parser() {
return this.parseContext.parser();
}
public QueryParseContext getParseContext() {
return this.parseContext;
}
}

View File

@ -21,15 +21,23 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHandler;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
public class RankEvalPlugin extends Plugin implements ActionPlugin { public class RankEvalPlugin extends Plugin implements ActionPlugin {
@ -38,9 +46,12 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
return Arrays.asList(new ActionHandler<>(RankEvalAction.INSTANCE, TransportRankEvalAction.class)); return Arrays.asList(new ActionHandler<>(RankEvalAction.INSTANCE, TransportRankEvalAction.class));
} }
@Override @Override
public List<Class<? extends RestHandler>> getRestHandlers() { public List<RestHandler> getRestHandlers(Settings settings, RestController restController, ClusterSettings clusterSettings,
return Arrays.asList(RestRankEvalAction.class); IndexScopedSettings indexScopedSettings, SettingsFilter settingsFilter, IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster) {
return Arrays.asList(new RestRankEvalAction(settings, restController));
} }
/** /**

View File

@ -23,7 +23,7 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException; import java.io.IOException;
@ -42,7 +42,7 @@ import java.util.Map;
* *
**/ **/
//TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results //TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results
public class RankEvalResponse extends ActionResponse implements ToXContent { public class RankEvalResponse extends ActionResponse implements ToXContentObject {
/**Average precision observed when issuing query intents with this specification.*/ /**Average precision observed when issuing query intents with this specification.*/
private double qualityLevel; private double qualityLevel;
/**Mapping from intent id to all documents seen for this intent that were not annotated.*/ /**Mapping from intent id to all documents seen for this intent that were not annotated.*/
@ -113,6 +113,7 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject("rank_eval"); builder.startObject("rank_eval");
builder.field("quality_level", qualityLevel); builder.field("quality_level", qualityLevel);
builder.startObject("details"); builder.startObject("details");
@ -123,11 +124,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
builder.startObject("failures"); builder.startObject("failures");
for (String key : failures.keySet()) { for (String key : failures.keySet()) {
builder.startObject(key); builder.startObject(key);
ElasticsearchException.renderException(builder, params, failures.get(key)); ElasticsearchException.generateFailureXContent(builder, params, failures.get(key), true);
builder.endObject(); builder.endObject();
} }
builder.endObject(); builder.endObject();
builder.endObject(); builder.endObject();
builder.endObject();
return builder; return builder;
} }
} }

View File

@ -149,37 +149,29 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
private static final ParseField REQUESTS_FIELD = new ParseField("requests"); private static final ParseField REQUESTS_FIELD = new ParseField("requests");
private static final ParseField MAX_CONCURRENT_SEARCHES_FIELD = new ParseField("max_concurrent_searches"); private static final ParseField MAX_CONCURRENT_SEARCHES_FIELD = new ParseField("max_concurrent_searches");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<RankEvalSpec, RankEvalContext> PARSER = private static final ConstructingObjectParser<RankEvalSpec, Void> PARSER =
new ConstructingObjectParser<>("rank_eval", new ConstructingObjectParser<>("rank_eval",
a -> new RankEvalSpec((Collection<RatedRequest>) a[0], (RankedListQualityMetric) a[1], (Collection<ScriptWithId>) a[2])); a -> new RankEvalSpec((Collection<RatedRequest>) a[0], (RankedListQualityMetric) a[1], (Collection<ScriptWithId>) a[2]));
static { static {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> { PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
try { return RatedRequest.fromXContent(p);
return RatedRequest.fromXContent(p, c);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
}
} , REQUESTS_FIELD); } , REQUESTS_FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
try { try {
return RankedListQualityMetric.fromXContent(p, c); return RankedListQualityMetric.fromXContent(p);
} catch (IOException ex) { } catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex); throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
} }
} , METRIC_FIELD); } , METRIC_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> { PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
try { return ScriptWithId.fromXContent(p);
return ScriptWithId.fromXContent(p, c);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
}
}, TEMPLATES_FIELD); }, TEMPLATES_FIELD);
PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD); PARSER.declareInt(RankEvalSpec::setMaxConcurrentSearches, MAX_CONCURRENT_SEARCHES_FIELD);
} }
public static RankEvalSpec parse(XContentParser parser, RankEvalContext context) throws IOException { public static RankEvalSpec parse(XContentParser parser) {
return PARSER.apply(parser, context); return PARSER.apply(parser, null);
} }
public static class ScriptWithId { public static class ScriptWithId {
@ -194,18 +186,18 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
this.script = script; this.script = script;
} }
private static final ConstructingObjectParser<ScriptWithId, RankEvalContext> PARSER = private static final ConstructingObjectParser<ScriptWithId, Void> PARSER =
new ConstructingObjectParser<>("script_with_id", a -> new ScriptWithId((String) a[0], (Script) a[1])); new ConstructingObjectParser<>("script_with_id", a -> new ScriptWithId((String) a[0], (Script) a[1]));
public static ScriptWithId fromXContent(XContentParser parser, RankEvalContext context) throws IOException { public static ScriptWithId fromXContent(XContentParser parser) {
return PARSER.apply(parser, context); return PARSER.apply(parser, null);
} }
static { static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TEMPLATE_ID_FIELD); PARSER.declareString(ConstructingObjectParser.constructorArg(), TEMPLATE_ID_FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
try { try {
return Script.parse(p, c.getParseFieldMatcher(), "mustache"); return Script.parse(p, "mustache");
} catch (IOException ex) { } catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex); throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
} }

View File

@ -19,7 +19,6 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
@ -54,7 +53,7 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
* */ * */
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs); EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
static RankedListQualityMetric fromXContent(XContentParser parser, ParseFieldMatcherSupplier context) throws IOException { static RankedListQualityMetric fromXContent(XContentParser parser) throws IOException {
RankedListQualityMetric rc; RankedListQualityMetric rc;
Token token = parser.nextToken(); Token token = parser.nextToken();
if (token != XContentParser.Token.FIELD_NAME) { if (token != XContentParser.Token.FIELD_NAME) {
@ -65,13 +64,13 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
// TODO maybe switch to using a plugable registry later? // TODO maybe switch to using a plugable registry later?
switch (metricName) { switch (metricName) {
case Precision.NAME: case Precision.NAME:
rc = Precision.fromXContent(parser, context); rc = Precision.fromXContent(parser);
break; break;
case ReciprocalRank.NAME: case ReciprocalRank.NAME:
rc = ReciprocalRank.fromXContent(parser, context); rc = ReciprocalRank.fromXContent(parser);
break; break;
case DiscountedCumulativeGain.NAME: case DiscountedCumulativeGain.NAME:
rc = DiscountedCumulativeGain.fromXContent(parser, context); rc = DiscountedCumulativeGain.fromXContent(parser);
break; break;
default: default:
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName); throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);

View File

@ -21,7 +21,6 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes; import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
@ -42,7 +41,7 @@ public class RatedDocument extends ToXContentToBytes implements Writeable {
public static final ParseField TYPE_FIELD = new ParseField("_type"); public static final ParseField TYPE_FIELD = new ParseField("_type");
public static final ParseField INDEX_FIELD = new ParseField("_index"); public static final ParseField INDEX_FIELD = new ParseField("_index");
private static final ConstructingObjectParser<RatedDocument, ParseFieldMatcherSupplier> PARSER = private static final ConstructingObjectParser<RatedDocument, Void> PARSER =
new ConstructingObjectParser<>("rated_document", new ConstructingObjectParser<>("rated_document",
a -> new RatedDocument((String) a[0], (String) a[1], (String) a[2], (Integer) a[3])); a -> new RatedDocument((String) a[0], (String) a[1], (String) a[2], (Integer) a[3]));
@ -96,8 +95,8 @@ public class RatedDocument extends ToXContentToBytes implements Writeable {
out.writeVInt(rating); out.writeVInt(rating);
} }
public static RatedDocument fromXContent(XContentParser parser, ParseFieldMatcherSupplier supplier) throws IOException { public static RatedDocument fromXContent(XContentParser parser) {
return PARSER.apply(parser, supplier); return PARSER.apply(parser, null);
} }
@Override @Override

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException; import java.io.IOException;
@ -218,22 +219,18 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
private static final ParseField FIELDS_FIELD = new ParseField("summary_fields"); private static final ParseField FIELDS_FIELD = new ParseField("summary_fields");
private static final ParseField TEMPLATE_ID_FIELD = new ParseField("template_id"); private static final ParseField TEMPLATE_ID_FIELD = new ParseField("template_id");
private static final ConstructingObjectParser<RatedRequest, RankEvalContext> PARSER = private static final ConstructingObjectParser<RatedRequest, QueryParseContext> PARSER =
new ConstructingObjectParser<>("requests", a -> new RatedRequest( new ConstructingObjectParser<>("requests", a -> new RatedRequest(
(String) a[0], (List<RatedDocument>) a[1], (SearchSourceBuilder) a[2], (Map<String, Object>) a[3], (String) a[4])); (String) a[0], (List<RatedDocument>) a[1], (SearchSourceBuilder) a[2], (Map<String, Object>) a[3], (String) a[4]));
static { static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_FIELD); PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> { PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
try { return RatedDocument.fromXContent(p);
return RatedDocument.fromXContent(p, c);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing ratings", ex);
}
}, RATINGS_FIELD); }, RATINGS_FIELD);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> { PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
try { try {
return SearchSourceBuilder.fromXContent(c.getParseContext(), c.getAggs(), c.getSuggesters(), c.getSearchExtParsers()); return SearchSourceBuilder.fromXContent(c);
} catch (IOException ex) { } catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing request", ex); throw new ParsingException(p.getTokenLocation(), "error parsing request", ex);
} }
@ -270,8 +267,8 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
* "ratings": [{ "1": 1 }, { "2": 0 }, { "3": 1 } ] * "ratings": [{ "1": 1 }, { "2": 0 }, { "3": 1 } ]
* } * }
*/ */
public static RatedRequest fromXContent(XContentParser parser, RankEvalContext context) throws IOException { public static RatedRequest fromXContent(XContentParser parser) {
return PARSER.apply(parser, context); return PARSER.apply(parser, new QueryParseContext(parser));
} }
@Override @Override

View File

@ -20,7 +20,6 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
@ -117,15 +116,15 @@ public class ReciprocalRank implements RankedListQualityMetric {
} }
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ObjectParser<ReciprocalRank, ParseFieldMatcherSupplier> PARSER = new ObjectParser<>( private static final ObjectParser<ReciprocalRank, Void> PARSER = new ObjectParser<>(
"reciprocal_rank", () -> new ReciprocalRank()); "reciprocal_rank", () -> new ReciprocalRank());
static { static {
PARSER.declareInt(ReciprocalRank::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD); PARSER.declareInt(ReciprocalRank::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
} }
public static ReciprocalRank fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { public static ReciprocalRank fromXContent(XContentParser parser) {
return PARSER.apply(parser, matcher); return PARSER.apply(parser, null);
} }
@Override @Override

View File

@ -21,16 +21,12 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchRequestParsers;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -154,18 +150,10 @@ import static org.elasticsearch.rest.RestRequest.Method.POST;
* */ * */
public class RestRankEvalAction extends BaseRestHandler { public class RestRankEvalAction extends BaseRestHandler {
private SearchRequestParsers searchRequestParsers; //private ScriptService scriptService;
private ScriptService scriptService;
@Inject public RestRankEvalAction(Settings settings, RestController controller) {
public RestRankEvalAction(
Settings settings,
RestController controller,
SearchRequestParsers searchRequestParsers,
ScriptService scriptService) {
super(settings); super(settings);
this.searchRequestParsers = searchRequestParsers;
this.scriptService = scriptService;
controller.registerHandler(GET, "/_rank_eval", this); controller.registerHandler(GET, "/_rank_eval", this);
controller.registerHandler(POST, "/_rank_eval", this); controller.registerHandler(POST, "/_rank_eval", this);
controller.registerHandler(GET, "/{index}/_rank_eval", this); controller.registerHandler(GET, "/{index}/_rank_eval", this);
@ -178,21 +166,17 @@ public class RestRankEvalAction extends BaseRestHandler {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
RankEvalRequest rankEvalRequest = new RankEvalRequest(); RankEvalRequest rankEvalRequest = new RankEvalRequest();
try (XContentParser parser = request.contentOrSourceParamParser()) { try (XContentParser parser = request.contentOrSourceParamParser()) {
QueryParseContext parseContext = new QueryParseContext(searchRequestParsers.queryParsers, parser, parseFieldMatcher); parseRankEvalRequest(rankEvalRequest, request, parser);
// TODO can we get rid of aggregators parsers and suggesters?
parseRankEvalRequest(rankEvalRequest, request,
new RankEvalContext(parseFieldMatcher, parseContext, searchRequestParsers, scriptService));
} }
return channel -> client.executeLocally(RankEvalAction.INSTANCE, rankEvalRequest, return channel -> client.executeLocally(RankEvalAction.INSTANCE, rankEvalRequest,
new RestToXContentListener<RankEvalResponse>(channel)); new RestToXContentListener<RankEvalResponse>(channel));
} }
public static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, RankEvalContext context) private static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, XContentParser parser) {
throws IOException {
List<String> indices = Arrays.asList(Strings.splitStringByCommaToArray(request.param("index"))); List<String> indices = Arrays.asList(Strings.splitStringByCommaToArray(request.param("index")));
List<String> types = Arrays.asList(Strings.splitStringByCommaToArray(request.param("type"))); List<String> types = Arrays.asList(Strings.splitStringByCommaToArray(request.param("type")));
RankEvalSpec spec = null; RankEvalSpec spec = null;
spec = RankEvalSpec.parse(context.parser(), context); spec = RankEvalSpec.parse(parser);
for (RatedRequest specification : spec.getRatedRequests()) { for (RatedRequest specification : spec.getRatedRequests()) {
specification.setIndices(indices); specification.setIndices(indices);
specification.setTypes(types); specification.setTypes(types);

View File

@ -29,15 +29,15 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.script.CompiledScript; import org.elasticsearch.script.CompiledScript;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
@ -47,12 +47,14 @@ import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Queue;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
/** /**
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of * Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
* possible search requests (read: search specifications, expressed as query/search request templates) and compares the result * possible search requests (read: search specifications, expressed as query/search request templates) and compares the result
@ -66,17 +68,17 @@ import java.util.concurrent.atomic.AtomicInteger;
public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> { public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> {
private Client client; private Client client;
private ScriptService scriptService; private ScriptService scriptService;
private SearchRequestParsers searchRequestParsers;
Queue<RequestTask> taskQueue = new ConcurrentLinkedQueue<>(); Queue<RequestTask> taskQueue = new ConcurrentLinkedQueue<>();
private NamedXContentRegistry namedXContentRegistry;
@Inject @Inject
public TransportRankEvalAction(Settings settings, ThreadPool threadPool, ActionFilters actionFilters, public TransportRankEvalAction(Settings settings, ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService, IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService,
SearchRequestParsers searchRequestParsers, ScriptService scriptService) { ScriptService scriptService, NamedXContentRegistry namedXContentRegistry) {
super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver, super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver,
RankEvalRequest::new); RankEvalRequest::new);
this.searchRequestParsers = searchRequestParsers;
this.scriptService = scriptService; this.scriptService = scriptService;
this.namedXContentRegistry = namedXContentRegistry;
this.client = client; this.client = client;
} }
@ -93,7 +95,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
for (Entry<String, Script> entry : qualityTask.getTemplates().entrySet()) { for (Entry<String, Script> entry : qualityTask.getTemplates().entrySet()) {
scriptsWithoutParams.put( scriptsWithoutParams.put(
entry.getKey(), entry.getKey(),
scriptService.compile(entry.getValue(), ScriptContext.Standard.SEARCH, new HashMap<>())); scriptService.compile(entry.getValue(), ScriptContext.Standard.SEARCH));
} }
for (RatedRequest ratedRequest : ratedRequests) { for (RatedRequest ratedRequest : ratedRequests) {
@ -104,11 +106,10 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
Map<String, Object> params = ratedRequest.getParams(); Map<String, Object> params = ratedRequest.getParams();
String templateId = ratedRequest.getTemplateId(); String templateId = ratedRequest.getTemplateId();
CompiledScript compiled = scriptsWithoutParams.get(templateId); CompiledScript compiled = scriptsWithoutParams.get(templateId);
String resolvedRequest = ((BytesReference) (scriptService.executable(compiled, params).run())).utf8ToString(); BytesReference resolvedRequest = (BytesReference) (scriptService.executable(compiled, params).run());
try (XContentParser subParser = XContentFactory.xContent(resolvedRequest).createParser(resolvedRequest)) { try (XContentParser subParser = createParser(namedXContentRegistry, resolvedRequest, XContentType.JSON)) {
QueryParseContext parseContext = new QueryParseContext(searchRequestParsers.queryParsers, subParser, parseFieldMatcher); QueryParseContext parseContext = new QueryParseContext(subParser);
ratedSearchSource = SearchSourceBuilder.fromXContent(parseContext, searchRequestParsers.aggParsers, ratedSearchSource = SearchSourceBuilder.fromXContent(parseContext);
searchRequestParsers.suggesters, searchRequestParsers.searchExtParsers);
} catch (IOException e) { } catch (IOException e) {
listener.onFailure(e); listener.onFailure(e);
} }
@ -143,7 +144,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
public SearchRequest request; public SearchRequest request;
public RankEvalActionListener searchListener; public RankEvalActionListener searchListener;
public RequestTask(SearchRequest request, RankEvalActionListener listener) { RequestTask(SearchRequest request, RankEvalActionListener listener) {
this.request = request; this.request = request;
this.searchListener = listener; this.searchListener = listener;
} }

View File

@ -19,10 +19,13 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchHit; import org.elasticsearch.search.internal.InternalSearchHit;
@ -192,11 +195,12 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
+ " \"unknown_doc_rating\": 2,\n" + " \"unknown_doc_rating\": 2,\n"
+ " \"normalize\": true\n" + " \"normalize\": true\n"
+ "}"; + "}";
XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent); try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser, () -> ParseFieldMatcher.STRICT); DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
assertEquals(2, dcgAt.getUnknownDocRating().intValue()); assertEquals(2, dcgAt.getUnknownDocRating().intValue());
assertEquals(true, dcgAt.getNormalize()); assertEquals(true, dcgAt.getNormalize());
} }
}
public static DiscountedCumulativeGain createTestItem() { public static DiscountedCumulativeGain createTestItem() {
boolean normalize = randomBoolean(); boolean normalize = randomBoolean();
@ -206,14 +210,17 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
} }
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
DiscountedCumulativeGain testItem = createTestItem(); DiscountedCumulativeGain testItem = createTestItem();
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
try (XContentParser itemParser = createParser(shuffled)) {
itemParser.nextToken(); itemParser.nextToken();
itemParser.nextToken(); itemParser.nextToken();
DiscountedCumulativeGain parsedItem = DiscountedCumulativeGain.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); DiscountedCumulativeGain parsedItem = DiscountedCumulativeGain.fromXContent(itemParser);
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
}
public void testSerialization() throws IOException { public void testSerialization() throws IOException {
DiscountedCumulativeGain original = createTestItem(); DiscountedCumulativeGain original = createTestItem();

View File

@ -19,10 +19,13 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.Index; import org.elasticsearch.index.Index;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.SearchShardTarget;
@ -153,10 +156,11 @@ public class PrecisionTests extends ESTestCase {
String xContent = " {\n" String xContent = " {\n"
+ " \"relevant_rating_threshold\" : 2" + " \"relevant_rating_threshold\" : 2"
+ "}"; + "}";
XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent); try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
Precision precicionAt = Precision.fromXContent(parser, () -> ParseFieldMatcher.STRICT); Precision precicionAt = Precision.fromXContent(parser);
assertEquals(2, precicionAt.getRelevantRatingThreshold()); assertEquals(2, precicionAt.getRelevantRatingThreshold());
} }
}
public void testCombine() { public void testCombine() {
Precision metric = new Precision(); Precision metric = new Precision();
@ -183,14 +187,17 @@ public class PrecisionTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
Precision testItem = createTestItem(); Precision testItem = createTestItem();
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
try (XContentParser itemParser = createParser(shuffled)) {
itemParser.nextToken(); itemParser.nextToken();
itemParser.nextToken(); itemParser.nextToken();
Precision parsedItem = Precision.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); Precision parsedItem = Precision.fromXContent(itemParser);
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
}
public void testSerialization() throws IOException { public void testSerialization() throws IOException {
Precision original = createTestItem(); Precision original = createTestItem();

View File

@ -78,9 +78,7 @@ public class RankEvalResponseTests extends ESTestCase {
if (ESTestCase.randomBoolean()) { if (ESTestCase.randomBoolean()) {
builder.prettyPrint(); builder.prettyPrint();
} }
builder.startObject();
randomResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); randomResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
// TODO check the correctness of the output // TODO check the correctness of the output
} }
} }

View File

@ -19,30 +19,19 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.index.rankeval.RankEvalSpec.ScriptWithId; import org.elasticsearch.index.rankeval.RankEvalSpec.ScriptWithId;
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType; import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.suggest.Suggesters;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -56,30 +45,7 @@ import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.function.Supplier; import java.util.function.Supplier;
import static java.util.Collections.emptyList;
public class RankEvalSpecTests extends ESTestCase { public class RankEvalSpecTests extends ESTestCase {
private static SearchModule searchModule;
private static SearchRequestParsers searchRequestParsers;
/**
* setup for the whole base test class
*/
@BeforeClass
public static void init() {
AggregatorParsers aggsParsers = new AggregatorParsers(new ParseFieldRegistry<>("aggregation"),
new ParseFieldRegistry<>("aggregation_pipes"));
searchModule = new SearchModule(Settings.EMPTY, false, emptyList());
IndicesQueriesRegistry queriesRegistry = searchModule.getQueryParserRegistry();
Suggesters suggesters = searchModule.getSuggesters();
searchRequestParsers = new SearchRequestParsers(queriesRegistry, aggsParsers, suggesters, null);
}
@AfterClass
public static void afterClass() throws Exception {
searchModule = null;
searchRequestParsers = null;
}
private static <T> List<T> randomList(Supplier<T> randomSupplier) { private static <T> List<T> randomList(Supplier<T> randomSupplier) {
List<T> result = new ArrayList<>(); List<T> result = new ArrayList<>();
@ -90,7 +56,7 @@ public class RankEvalSpecTests extends ESTestCase {
return result; return result;
} }
private RankEvalSpec createTestItem() throws IOException { private static RankEvalSpec createTestItem() throws IOException {
RankedListQualityMetric metric; RankedListQualityMetric metric;
if (randomBoolean()) { if (randomBoolean()) {
metric = PrecisionTests.createTestItem(); metric = PrecisionTests.createTestItem();
@ -118,7 +84,7 @@ public class RankEvalSpecTests extends ESTestCase {
templates = new HashSet<>(); templates = new HashSet<>();
templates.add( templates.add(
new ScriptWithId("templateId", new Script(scriptType, randomFrom("_lang1", "_lang2"), script, params))); new ScriptWithId("templateId", new Script(scriptType, Script.DEFAULT_TEMPLATE_LANG, script, params)));
Map<String, Object> templateParams = new HashMap<>(); Map<String, Object> templateParams = new HashMap<>();
templateParams.put("key", "value"); templateParams.put("key", "value");
@ -138,20 +104,19 @@ public class RankEvalSpecTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
RankEvalSpec testItem = createTestItem(); RankEvalSpec testItem = createTestItem();
XContentBuilder shuffled = ESTestCase.shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes()); try (XContentParser parser = createParser(JsonXContent.jsonXContent, shuffled.bytes())) {
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT); RankEvalSpec parsedItem = RankEvalSpec.parse(parser);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers, null);
RankEvalSpec parsedItem = RankEvalSpec.parse(itemParser, rankContext);
// IRL these come from URL parameters - see RestRankEvalAction // IRL these come from URL parameters - see RestRankEvalAction
// TODO Do we still need this? parsedItem.getRatedRequests().stream().forEach(e -> {e.setIndices(indices); e.setTypes(types);}); // TODO Do we still need this?
// parsedItem.getRatedRequests().stream().forEach(e ->
// {e.setIndices(indices); e.setTypes(types);});
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
}
public void testSerialization() throws IOException { public void testSerialization() throws IOException {
RankEvalSpec original = createTestItem(); RankEvalSpec original = createTestItem();
@ -185,7 +150,7 @@ public class RankEvalSpecTests extends ESTestCase {
RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables))); RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)));
} }
private RankEvalSpec mutateTestItem(RankEvalSpec mutant) { private static RankEvalSpec mutateTestItem(RankEvalSpec mutant) {
Collection<RatedRequest> ratedRequests = mutant.getRatedRequests(); Collection<RatedRequest> ratedRequests = mutant.getRatedRequests();
RankedListQualityMetric metric = mutant.getMetric(); RankedListQualityMetric metric = mutant.getMetric();
Map<String, Script> templates = mutant.getTemplates(); Map<String, Script> templates = mutant.getTemplates();
@ -193,7 +158,7 @@ public class RankEvalSpecTests extends ESTestCase {
int mutate = randomIntBetween(0, 2); int mutate = randomIntBetween(0, 2);
switch (mutate) { switch (mutate) {
case 0: case 0:
RatedRequest request = RatedRequestsTests.createTestItem(new ArrayList<>(), new ArrayList<>()); RatedRequest request = RatedRequestsTests.createTestItem(new ArrayList<>(), new ArrayList<>(), true);
ratedRequests.add(request); ratedRequests.add(request);
break; break;
case 1: case 1:
@ -205,13 +170,8 @@ public class RankEvalSpecTests extends ESTestCase {
break; break;
case 2: case 2:
if (templates.size() > 0) { if (templates.size() > 0) {
if (randomBoolean()) {
templates = null;
} else {
String mutatedTemplate = randomAsciiOfLength(10); String mutatedTemplate = randomAsciiOfLength(10);
templates.put("mutation", new Script(ScriptType.INLINE, "mustache", mutatedTemplate, new HashMap<>())); templates.put("mutation", new Script(ScriptType.INLINE, "mustache", mutatedTemplate, new HashMap<>()));
}
} else { } else {
String mutatedTemplate = randomValueOtherThanMany(templates::containsValue, () -> randomAsciiOfLength(10)); String mutatedTemplate = randomValueOtherThanMany(templates::containsValue, () -> randomAsciiOfLength(10));
templates.put("mutation", new Script(ScriptType.INLINE, "mustache", mutatedTemplate, new HashMap<>())); templates.put("mutation", new Script(ScriptType.INLINE, "mustache", mutatedTemplate, new HashMap<>()));
@ -238,7 +198,7 @@ public class RankEvalSpecTests extends ESTestCase {
public void testMissingMetricFailsParsing() { public void testMissingMetricFailsParsing() {
List<String> strings = Arrays.asList("value"); List<String> strings = Arrays.asList("value");
List<RatedRequest> ratedRequests = randomList(() -> RatedRequestsTests.createTestItem(strings, strings)); List<RatedRequest> ratedRequests = randomList(() -> RatedRequestsTests.createTestItem(strings, strings, randomBoolean()));
expectThrows(IllegalStateException.class, () -> new RankEvalSpec(ratedRequests, null)); expectThrows(IllegalStateException.class, () -> new RankEvalSpec(ratedRequests, null));
} }

View File

@ -25,13 +25,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -45,17 +38,6 @@ import static org.junit.Assert.assertTrue;
public class RankEvalTestHelper { public class RankEvalTestHelper {
public static XContentParser roundtrip(ToXContent testItem) throws IOException {
XContentBuilder builder = XContentFactory.contentBuilder(ESTestCase.randomFrom(XContentType.values()));
if (ESTestCase.randomBoolean()) {
builder.prettyPrint();
}
testItem.toXContent(builder, ToXContent.EMPTY_PARAMS);
XContentBuilder shuffled = ESTestCase.shuffleXContent(builder);
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes());
return itemParser;
}
public static <T> void testHashCodeAndEquals(T testItem, T mutation, T secondCopy) { public static <T> void testHashCodeAndEquals(T testItem, T mutation, T secondCopy) {
assertFalse("testItem is equal to null", testItem.equals(null)); assertFalse("testItem is equal to null", testItem.equals(null));
assertFalse("testItem is equal to incompatible type", testItem.equals("")); assertFalse("testItem is equal to incompatible type", testItem.equals(""));

View File

@ -24,7 +24,6 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
import java.io.IOException; import java.io.IOException;
@ -34,7 +33,7 @@ public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
} }
@ParametersFactory @ParametersFactory
public static Iterable<Object[]> parameters() throws IOException, ClientYamlTestParseException { public static Iterable<Object[]> parameters() throws IOException {
return ESClientYamlSuiteTestCase.createParameters(); return ESClientYamlSuiteTestCase.createParameters();
} }
} }

View File

@ -19,8 +19,11 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;
@ -38,12 +41,15 @@ public class RatedDocumentTests extends ESTestCase {
public void testXContentParsing() throws IOException { public void testXContentParsing() throws IOException {
RatedDocument testItem = createRatedDocument(); RatedDocument testItem = createRatedDocument();
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
RatedDocument parsedItem = RatedDocument.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
try (XContentParser itemParser = createParser(shuffled)) {
RatedDocument parsedItem = RatedDocument.fromXContent(itemParser);
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
}
public void testSerialization() throws IOException { public void testSerialization() throws IOException {
RatedDocument original = createRatedDocument(); RatedDocument original = createRatedDocument();

View File

@ -19,21 +19,19 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.suggest.Suggesters;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.BeforeClass; import org.junit.BeforeClass;
@ -44,36 +42,38 @@ import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;
public class RatedRequestsTests extends ESTestCase { public class RatedRequestsTests extends ESTestCase {
private static SearchModule searchModule; private static NamedXContentRegistry xContentRegistry;
private static SearchRequestParsers searchRequestParsers;
/** /**
* setup for the whole base test class * setup for the whole base test class
*/ */
@BeforeClass @BeforeClass
public static void init() throws IOException { public static void init() {
AggregatorParsers aggsParsers = new AggregatorParsers(new ParseFieldRegistry<>("aggregation"), xContentRegistry = new NamedXContentRegistry(Stream.of(
new ParseFieldRegistry<>("aggregation_pipes")); new SearchModule(Settings.EMPTY, false, emptyList()).getNamedXContents().stream()
searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); ).flatMap(Function.identity()).collect(toList()));
IndicesQueriesRegistry queriesRegistry = searchModule.getQueryParserRegistry();
Suggesters suggesters = searchModule.getSuggesters();
searchRequestParsers = new SearchRequestParsers(queriesRegistry, aggsParsers, suggesters, null);
} }
@AfterClass @AfterClass
public static void afterClass() throws Exception { public static void afterClass() throws Exception {
searchModule = null; xContentRegistry = null;
searchRequestParsers = null;
} }
public static RatedRequest createTestItem(List<String> indices, List<String> types) { @Override
String requestId = randomAsciiOfLength(50); protected NamedXContentRegistry xContentRegistry() {
return xContentRegistry;
}
public static RatedRequest createTestItem(List<String> indices, List<String> types, boolean forceRequest) {
String requestId = randomAsciiOfLength(50);
List<RatedDocument> ratedDocs = new ArrayList<>(); List<RatedDocument> ratedDocs = new ArrayList<>();
int size = randomIntBetween(0, 2); int size = randomIntBetween(0, 2);
@ -83,15 +83,15 @@ public class RatedRequestsTests extends ESTestCase {
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
SearchSourceBuilder testRequest = null; SearchSourceBuilder testRequest = null;
if (randomBoolean()) { if (randomBoolean() || forceRequest) {
testRequest = new SearchSourceBuilder();
testRequest.size(randomInt());
testRequest.query(new MatchAllQueryBuilder());
} else {
int randomSize = randomIntBetween(1, 10); int randomSize = randomIntBetween(1, 10);
for (int i = 0; i < randomSize; i++) { for (int i = 0; i < randomSize; i++) {
params.put(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10)); params.put(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10));
} }
} else {
testRequest = new SearchSourceBuilder();
testRequest.size(randomInt());
testRequest.query(new MatchAllQueryBuilder());
} }
List<String> summaryFields = new ArrayList<>(); List<String> summaryFields = new ArrayList<>();
@ -112,8 +112,6 @@ public class RatedRequestsTests extends ESTestCase {
ratedRequest.setTypes(types); ratedRequest.setTypes(types);
ratedRequest.setSummaryFields(summaryFields); ratedRequest.setSummaryFields(summaryFields);
} }
return ratedRequest; return ratedRequest;
} }
@ -130,21 +128,23 @@ public class RatedRequestsTests extends ESTestCase {
types.add(randomAsciiOfLengthBetween(0, 50)); types.add(randomAsciiOfLengthBetween(0, 50));
} }
RatedRequest testItem = createTestItem(indices, types); RatedRequest testItem = createTestItem(indices, types, randomBoolean());
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
try (XContentParser itemParser = createParser(shuffled)) {
itemParser.nextToken(); itemParser.nextToken();
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT); RatedRequest parsedItem = RatedRequest.fromXContent(itemParser);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext, parsedItem.setIndices(indices); // IRL these come from URL
searchRequestParsers, null); // parameters - see
// RestRankEvalAction
RatedRequest parsedItem = RatedRequest.fromXContent(itemParser, rankContext); parsedItem.setTypes(types); // IRL these come from URL parameters -
parsedItem.setIndices(indices); // IRL these come from URL parameters - see RestRankEvalAction // see RestRankEvalAction
parsedItem.setTypes(types); // IRL these come from URL parameters - see RestRankEvalAction
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
}
public void testSerialization() throws IOException { public void testSerialization() throws IOException {
List<String> indices = new ArrayList<>(); List<String> indices = new ArrayList<>();
@ -159,7 +159,7 @@ public class RatedRequestsTests extends ESTestCase {
types.add(randomAsciiOfLengthBetween(0, 50)); types.add(randomAsciiOfLengthBetween(0, 50));
} }
RatedRequest original = createTestItem(indices, types); RatedRequest original = createTestItem(indices, types, randomBoolean());
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(); List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
@ -183,7 +183,7 @@ public class RatedRequestsTests extends ESTestCase {
types.add(randomAsciiOfLengthBetween(0, 50)); types.add(randomAsciiOfLengthBetween(0, 50));
} }
RatedRequest testItem = createTestItem(indices, types); RatedRequest testItem = createTestItem(indices, types, randomBoolean());
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(); List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
@ -338,11 +338,8 @@ public class RatedRequestsTests extends ESTestCase {
+ " {\"_type\": \"testtype\", \"_index\": \"test\", \"_id\": \"2\", \"rating\" : 0 }, " + " {\"_type\": \"testtype\", \"_index\": \"test\", \"_id\": \"2\", \"rating\" : 0 }, "
+ " {\"_id\": \"3\", \"_index\": \"test\", \"_type\": \"testtype\", \"rating\" : 1 }]\n" + " {\"_id\": \"3\", \"_index\": \"test\", \"_type\": \"testtype\", \"rating\" : 1 }]\n"
+ "}"; + "}";
XContentParser parser = XContentFactory.xContent(querySpecString).createParser(querySpecString); try (XContentParser parser = createParser(JsonXContent.jsonXContent, querySpecString)) {
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, parser, ParseFieldMatcher.STRICT); RatedRequest specification = RatedRequest.fromXContent(parser);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers, null);
RatedRequest specification = RatedRequest.fromXContent(parser, rankContext);
assertEquals("my_qa_query", specification.getId()); assertEquals("my_qa_query", specification.getId());
assertNotNull(specification.getTestRequest()); assertNotNull(specification.getTestRequest());
List<RatedDocument> ratedDocs = specification.getRatedDocs(); List<RatedDocument> ratedDocs = specification.getRatedDocs();
@ -358,4 +355,5 @@ public class RatedRequestsTests extends ESTestCase {
} }
} }
} }
}
} }

View File

@ -19,9 +19,12 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.Index; import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.PrecisionTests.Rating; import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
@ -124,14 +127,17 @@ public class ReciprocalRankTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
ReciprocalRank testItem = createTestItem(); ReciprocalRank testItem = createTestItem();
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
try (XContentParser itemParser = createParser(shuffled)) {
itemParser.nextToken(); itemParser.nextToken();
itemParser.nextToken(); itemParser.nextToken();
ReciprocalRank parsedItem = ReciprocalRank.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); ReciprocalRank parsedItem = ReciprocalRank.fromXContent(itemParser);
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
}
/** /**
* Create InternalSearchHits for testing, starting from dociId 'from' up to docId 'to'. * Create InternalSearchHits for testing, starting from dociId 'from' up to docId 'to'.

View File

@ -17,6 +17,7 @@
* under the License. * under the License.
*/ */
apply plugin: 'elasticsearch.standalone-rest-test'
apply plugin: 'elasticsearch.rest-test' apply plugin: 'elasticsearch.rest-test'

View File

@ -24,7 +24,6 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
import java.io.IOException; import java.io.IOException;
@ -35,7 +34,7 @@ public class SmokeTestRankEvalWithMustacheYAMLTestSuiteIT extends ESClientYamlSu
} }
@ParametersFactory @ParametersFactory
public static Iterable<Object[]> parameters() throws IOException, ClientYamlTestParseException { public static Iterable<Object[]> parameters() throws IOException {
return ESClientYamlSuiteTestCase.createParameters(); return ESClientYamlSuiteTestCase.createParameters();
} }