diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java index 487aecc7dd3..15f51d5e41b 100755 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java @@ -63,6 +63,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.VersionType; +import org.elasticsearch.index.rankeval.RankEvalRequest; import org.elasticsearch.rest.action.search.RestSearchAction; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; @@ -71,6 +72,7 @@ import java.io.IOException; import java.nio.charset.Charset; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -484,6 +486,16 @@ public final class Request { return new Request(HttpHead.METHOD_NAME, endpoint, params.getParams(), null); } + static Request rankEval(RankEvalRequest rankEvalRequest) throws IOException { + // TODO maybe indices should be propery of RankEvalRequest and not of the spec + List indices = rankEvalRequest.getRankEvalSpec().getIndices(); + String endpoint = endpoint(indices.toArray(new String[indices.size()]), Strings.EMPTY_ARRAY, "_rank_eval"); + HttpEntity entity = null; + entity = createEntity(rankEvalRequest.getRankEvalSpec(), REQUEST_BODY_CONTENT_TYPE); + return new Request(HttpGet.METHOD_NAME, endpoint, Collections.emptyMap(), entity); + + } + private static HttpEntity createEntity(ToXContent toXContent, XContentType xContentType) throws IOException { BytesRef source = XContentHelper.toXContent(toXContent, xContentType, false).toBytesRef(); return new ByteArrayEntity(source.bytes, source.offset, source.length, createContentType(xContentType)); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java index 9fb53a54d8c..c8e248657dd 100755 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java @@ -54,6 +54,8 @@ import org.elasticsearch.common.xcontent.ContextParser; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.rankeval.RankEvalRequest; +import org.elasticsearch.index.rankeval.RankEvalResponse; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestStatus; @@ -467,6 +469,27 @@ public class RestHighLevelClient implements Closeable { listener, emptySet(), headers); } + /** + * Executes a request using the Ranking Evaluation API. + * + * See Ranking Evaluation API + * on elastic.co + */ + public final RankEvalResponse rankEval(RankEvalRequest rankEvalRequest, Header... headers) throws IOException { + return performRequestAndParseEntity(rankEvalRequest, Request::rankEval, RankEvalResponse::fromXContent, emptySet(), headers); + } + + /** + * Asynchronously executes a request using the Ranking Evaluation API. + * + * See Ranking Evaluation API + * on elastic.co + */ + public final void rankEvalAsync(RankEvalRequest rankEvalRequest, ActionListener listener, Header... headers) { + performRequestAsyncAndParseEntity(rankEvalRequest, Request::rankEval, RankEvalResponse::fromXContent, listener, emptySet(), + headers); + } + protected final Resp performRequestAndParseEntity(Req request, CheckedFunction requestConverter, CheckedFunction entityParser, diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java new file mode 100644 index 00000000000..c65f7e9da5b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java @@ -0,0 +1,120 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client; + +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.rankeval.EvalQueryQuality; +import org.elasticsearch.index.rankeval.PrecisionAtK; +import org.elasticsearch.index.rankeval.RankEvalRequest; +import org.elasticsearch.index.rankeval.RankEvalResponse; +import org.elasticsearch.index.rankeval.RankEvalSpec; +import org.elasticsearch.index.rankeval.RatedDocument; +import org.elasticsearch.index.rankeval.RatedRequest; +import org.elasticsearch.index.rankeval.RatedSearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map.Entry; +import java.util.Set; + +import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments; + +public class RankEvalIT extends ESRestHighLevelClientTestCase { + + @Before + public void indexDocuments() throws IOException { + StringEntity doc = new StringEntity("{\"text\":\"berlin\"}", ContentType.APPLICATION_JSON); + client().performRequest("PUT", "/index/doc/1", Collections.emptyMap(), doc); + doc = new StringEntity("{\"text\":\"amsterdam\"}", ContentType.APPLICATION_JSON); + client().performRequest("PUT", "/index/doc/2", Collections.emptyMap(), doc); + client().performRequest("PUT", "/index/doc/3", Collections.emptyMap(), doc); + client().performRequest("PUT", "/index/doc/4", Collections.emptyMap(), doc); + client().performRequest("PUT", "/index/doc/5", Collections.emptyMap(), doc); + client().performRequest("PUT", "/index/doc/6", Collections.emptyMap(), doc); + client().performRequest("POST", "/index/_refresh"); + } + + /** + * Test cases retrieves all six documents indexed above and checks the Prec@10 + * calculation where all unlabeled documents are treated as not relevant. + */ + public void testRankEvalRequest() throws IOException { + SearchSourceBuilder testQuery = new SearchSourceBuilder(); + testQuery.query(new MatchAllQueryBuilder()); + RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", createRelevant("index" , "2", "3", "4", "5"), testQuery); + RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "1"), testQuery); + List specifications = new ArrayList<>(); + specifications.add(amsterdamRequest); + specifications.add(berlinRequest); + PrecisionAtK metric = new PrecisionAtK(1, false, 10); + RankEvalSpec spec = new RankEvalSpec(specifications, metric); + spec.addIndices(Collections.singletonList("index")); + + RankEvalResponse response = execute(new RankEvalRequest(spec), highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); + // the expected Prec@ for the first query is 4/6 and the expected Prec@ for the second is 1/6, divided by 2 to get the average + double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0; + assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE); + Set> entrySet = response.getPartialResults().entrySet(); + assertEquals(2, entrySet.size()); + for (Entry entry : entrySet) { + EvalQueryQuality quality = entry.getValue(); + if (entry.getKey() == "amsterdam_query") { + assertEquals(2, filterUnknownDocuments(quality.getHitsAndRatings()).size()); + List hitsAndRatings = quality.getHitsAndRatings(); + assertEquals(6, hitsAndRatings.size()); + for (RatedSearchHit hit : hitsAndRatings) { + String id = hit.getSearchHit().getId(); + if (id.equals("1") || id.equals("6")) { + assertFalse(hit.getRating().isPresent()); + } else { + assertEquals(1, hit.getRating().get().intValue()); + } + } + } + if (entry.getKey() == "berlin_query") { + assertEquals(5, filterUnknownDocuments(quality.getHitsAndRatings()).size()); + List hitsAndRatings = quality.getHitsAndRatings(); + assertEquals(6, hitsAndRatings.size()); + for (RatedSearchHit hit : hitsAndRatings) { + String id = hit.getSearchHit().getId(); + if (id.equals("1")) { + assertEquals(1, hit.getRating().get().intValue()); + } else { + assertFalse(hit.getRating().isPresent()); + } + } + } + } + } + + private static List createRelevant(String indexName, String... docs) { + List relevant = new ArrayList<>(); + for (String doc : docs) { + relevant.add(new RatedDocument(indexName, doc, 1)); + } + return relevant; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java index 6a27db2797a..1c36d2a2219 100755 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java @@ -70,6 +70,11 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.VersionType; import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.rankeval.PrecisionAtK; +import org.elasticsearch.index.rankeval.RankEvalRequest; +import org.elasticsearch.index.rankeval.RankEvalSpec; +import org.elasticsearch.index.rankeval.RatedRequest; +import org.elasticsearch.index.rankeval.RestRankEvalAction; import org.elasticsearch.rest.action.search.RestSearchAction; import org.elasticsearch.search.Scroll; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; @@ -89,6 +94,8 @@ import java.io.InputStream; import java.lang.reflect.Constructor; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -1025,6 +1032,26 @@ public class RequestTests extends ESTestCase { IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> Request.existsAlias(getAliasesRequest)); assertEquals("existsAlias requires at least an alias or an index", iae.getMessage()); } + + public void testRankEval() throws Exception { + RankEvalSpec spec = new RankEvalSpec( + Collections.singletonList(new RatedRequest("queryId", Collections.emptyList(), new SearchSourceBuilder())), + new PrecisionAtK()); + String[] indices = randomIndicesNames(0, 5); + spec.addIndices(Arrays.asList(indices)); + RankEvalRequest rankEvalRequest = new RankEvalRequest(spec); + + Request request = Request.rankEval(rankEvalRequest); + StringJoiner endpoint = new StringJoiner("/", "/", ""); + String index = String.join(",", indices); + if (Strings.hasLength(index)) { + endpoint.add(index); + } + endpoint.add(RestRankEvalAction.ENDPOINT); + assertEquals(endpoint.toString(), request.getEndpoint()); + assertEquals(Collections.emptyMap(), request.getParameters()); + assertToXContentBody(spec, request.getEntity()); + } private static void assertToXContentBody(ToXContent expectedBody, HttpEntity actualEntity) throws IOException { BytesReference expectedBytes = XContentHelper.toXContent(expectedBody, REQUEST_BODY_CONTENT_TYPE, false); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index ace773239f8..79b8d8d1503 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -64,6 +64,7 @@ import org.elasticsearch.common.xcontent.smile.SmileXContent; import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvaluationMetric; import org.elasticsearch.index.rankeval.MeanReciprocalRank; +import org.elasticsearch.index.rankeval.MetricDetail; import org.elasticsearch.index.rankeval.PrecisionAtK; import org.elasticsearch.join.aggregations.ChildrenAggregationBuilder; import org.elasticsearch.rest.RestStatus; @@ -656,7 +657,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(5, namedXContents.size()); + assertEquals(7, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -666,7 +667,7 @@ public class RestHighLevelClientTests extends ESTestCase { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals(2, categories.size()); + assertEquals(3, categories.size()); assertEquals(Integer.valueOf(2), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -674,6 +675,9 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); + assertEquals(Integer.valueOf(2), categories.get(MetricDetail.class)); + assertTrue(names.contains(PrecisionAtK.NAME)); + assertTrue(names.contains(MeanReciprocalRank.NAME)); } private static class TrackingActionListener implements ActionListener { diff --git a/docs/reference/search/rank-eval.asciidoc b/docs/reference/search/rank-eval.asciidoc index 6e834d5e60c..53c6ac9cf60 100644 --- a/docs/reference/search/rank-eval.asciidoc +++ b/docs/reference/search/rank-eval.asciidoc @@ -283,8 +283,10 @@ that shows potential errors of individual queries. The response has the followin }, [...] ], "metric_details": { <6> - "relevant_docs_retrieved": 6, - "docs_retrieved": 10 + "precision" : { + "relevant_docs_retrieved": 6, + "docs_retrieved": 10 + } } }, "my_query_id2" : { [...] } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java index 64d4ada0dc1..8ac09993b7c 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java @@ -164,7 +164,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric { private static final ParseField K_FIELD = new ParseField("k"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at", + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at", true, args -> { Boolean normalized = (Boolean) args[0]; Integer optK = (Integer) args[2]; diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java index 28162c47441..c683c54bfdd 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java @@ -19,11 +19,15 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import java.io.IOException; @@ -34,22 +38,32 @@ import java.util.Objects;; /** * Result of the evaluation metric calculation on one particular query alone. */ -public class EvalQueryQuality implements ToXContent, Writeable { +public class EvalQueryQuality implements ToXContentFragment, Writeable { private final String queryId; private final double evaluationResult; - private MetricDetails optionalMetricDetails; - private final List ratedHits = new ArrayList<>(); + private MetricDetail optionalMetricDetails; + private final List ratedHits; public EvalQueryQuality(String id, double evaluationResult) { this.queryId = id; this.evaluationResult = evaluationResult; + this.ratedHits = new ArrayList<>(); } public EvalQueryQuality(StreamInput in) throws IOException { - this(in.readString(), in.readDouble()); - this.ratedHits.addAll(in.readList(RatedSearchHit::new)); - this.optionalMetricDetails = in.readOptionalNamedWriteable(MetricDetails.class); + this.queryId = in.readString(); + this.evaluationResult = in.readDouble(); + this.ratedHits = in.readList(RatedSearchHit::new); + this.optionalMetricDetails = in.readOptionalNamedWriteable(MetricDetail.class); + } + + // only used for parsing internally + private EvalQueryQuality(String queryId, ParsedEvalQueryQuality builder) { + this.queryId = queryId; + this.evaluationResult = builder.evaluationResult; + this.optionalMetricDetails = builder.optionalMetricDetails; + this.ratedHits = builder.ratedHits; } @Override @@ -68,11 +82,11 @@ public class EvalQueryQuality implements ToXContent, Writeable { return evaluationResult; } - public void setMetricDetails(MetricDetails breakdown) { + public void setMetricDetails(MetricDetail breakdown) { this.optionalMetricDetails = breakdown; } - public MetricDetails getMetricDetails() { + public MetricDetail getMetricDetails() { return this.optionalMetricDetails; } @@ -87,8 +101,8 @@ public class EvalQueryQuality implements ToXContent, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(queryId); - builder.field("quality_level", this.evaluationResult); - builder.startArray("unknown_docs"); + builder.field(QUALITY_LEVEL_FIELD.getPreferredName(), this.evaluationResult); + builder.startArray(UNKNOWN_DOCS_FIELD.getPreferredName()); for (DocumentKey key : EvaluationMetric.filterUnknownDocuments(ratedHits)) { builder.startObject(); builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), key.getIndex()); @@ -96,20 +110,50 @@ public class EvalQueryQuality implements ToXContent, Writeable { builder.endObject(); } builder.endArray(); - builder.startArray("hits"); + builder.startArray(HITS_FIELD.getPreferredName()); for (RatedSearchHit hit : ratedHits) { hit.toXContent(builder, params); } builder.endArray(); if (optionalMetricDetails != null) { - builder.startObject("metric_details"); - optionalMetricDetails.toXContent(builder, params); - builder.endObject(); + builder.field(METRIC_DETAILS_FIELD.getPreferredName(), optionalMetricDetails); } builder.endObject(); return builder; } + private static final ParseField QUALITY_LEVEL_FIELD = new ParseField("quality_level"); + private static final ParseField UNKNOWN_DOCS_FIELD = new ParseField("unknown_docs"); + private static final ParseField HITS_FIELD = new ParseField("hits"); + private static final ParseField METRIC_DETAILS_FIELD = new ParseField("metric_details"); + private static final ObjectParser PARSER = new ObjectParser<>("eval_query_quality", + true, ParsedEvalQueryQuality::new); + + private static class ParsedEvalQueryQuality { + double evaluationResult; + MetricDetail optionalMetricDetails; + List ratedHits = new ArrayList<>(); + } + + static { + PARSER.declareDouble((obj, value) -> obj.evaluationResult = value, QUALITY_LEVEL_FIELD); + PARSER.declareObject((obj, value) -> obj.optionalMetricDetails = value, (p, c) -> parseMetricDetail(p), + METRIC_DETAILS_FIELD); + PARSER.declareObjectArray((obj, list) -> obj.ratedHits = list, (p, c) -> RatedSearchHit.parse(p), HITS_FIELD); + } + + private static MetricDetail parseMetricDetail(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + MetricDetail metricDetail = parser.namedObject(MetricDetail.class, parser.currentName(), null); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return metricDetail; + } + + public static EvalQueryQuality fromXContent(XContentParser parser, String queryId) throws IOException { + return new EvalQueryQuality(queryId, PARSER.apply(parser, null)); + } + @Override public final boolean equals(Object obj) { if (this == obj) { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java index a74fd8da3e6..ef510b399d4 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java @@ -32,6 +32,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; @@ -180,9 +181,10 @@ public class MeanReciprocalRank implements EvaluationMetric { return Objects.hash(relevantRatingThreshhold, k); } - static class Breakdown implements MetricDetails { + static class Breakdown implements MetricDetail { private final int firstRelevantRank; + private static ParseField FIRST_RELEVANT_RANK_FIELD = new ParseField("first_relevant"); Breakdown(int firstRelevantRank) { this.firstRelevantRank = firstRelevantRank; @@ -193,10 +195,27 @@ public class MeanReciprocalRank implements EvaluationMetric { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) + public + String getMetricName() { + return NAME; + } + + @Override + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("first_relevant", firstRelevantRank); - return builder; + return builder.field(FIRST_RELEVANT_RANK_FIELD.getPreferredName(), firstRelevantRank); + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, args -> { + return new Breakdown((Integer) args[0]); + }); + + static { + PARSER.declareInt(constructorArg(), FIRST_RELEVANT_RANK_FIELD); + } + + public static Breakdown fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); } @Override diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetails.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java similarity index 55% rename from modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetails.java rename to modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java index 22b0ed19cf2..bc95b03c8bd 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetails.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java @@ -20,11 +20,31 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; /** * Details about a specific {@link EvaluationMetric} that should be included in the resonse. */ -public interface MetricDetails extends ToXContent, NamedWriteable { +public interface MetricDetail extends ToXContentObject, NamedWriteable { + @Override + default XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startObject(getMetricName()); + innerToXContent(builder, params); + builder.endObject(); + return builder.endObject(); + }; + + default String getMetricName() { + return getWriteableName(); + } + + /** + * Implementations should write their own fields to the {@link XContentBuilder} passed in. + */ + XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java index 63bdcb7307d..15d955935ee 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java @@ -34,6 +34,7 @@ import java.util.Optional; import javax.naming.directory.SearchResult; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; @@ -216,10 +217,10 @@ public class PrecisionAtK implements EvaluationMetric { return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k); } - static class Breakdown implements MetricDetails { + static class Breakdown implements MetricDetail { - private static final String DOCS_RETRIEVED_FIELD = "docs_retrieved"; - private static final String RELEVANT_DOCS_RETRIEVED_FIELD = "relevant_docs_retrieved"; + private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved"); + private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved"); private int relevantRetrieved; private int retrieved; @@ -234,13 +235,26 @@ public class PrecisionAtK implements EvaluationMetric { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RELEVANT_DOCS_RETRIEVED_FIELD, relevantRetrieved); - builder.field(DOCS_RETRIEVED_FIELD, retrieved); + builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved); + builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved); return builder; } + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, args -> { + return new Breakdown((Integer) args[0], (Integer) args[1]); + }); + + static { + PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD); + PARSER.declareInt(constructorArg(), DOCS_RETRIEVED_FIELD); + } + + public static Breakdown fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeVInt(relevantRetrieved); diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java index ba241248a02..54d68774a01 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java @@ -37,6 +37,10 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider { MeanReciprocalRank::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME), DiscountedCumulativeGain::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME), + PrecisionAtK.Breakdown::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME), + MeanReciprocalRank.Breakdown::fromXContent)); return namedXContent; } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 6b3a23d07d5..d4ccd7c2180 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -60,9 +60,9 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin { namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME, PrecisionAtK.Breakdown::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Breakdown::new)); namedWriteables - .add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new)); + .add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new)); return namedWriteables; } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalRequest.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalRequest.java index 637b9a18844..c682ec45ed6 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalRequest.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalRequest.java @@ -31,12 +31,19 @@ import java.io.IOException; */ public class RankEvalRequest extends ActionRequest { - private RankEvalSpec rankingEvaluation; + private RankEvalSpec rankingEvaluationSpec; + + public RankEvalRequest(RankEvalSpec rankingEvaluationSpec) { + this.rankingEvaluationSpec = rankingEvaluationSpec; + } + + RankEvalRequest() { + } @Override public ActionRequestValidationException validate() { ActionRequestValidationException e = null; - if (rankingEvaluation == null) { + if (rankingEvaluationSpec == null) { e = new ActionRequestValidationException(); e.addValidationError("missing ranking evaluation specification"); } @@ -47,26 +54,26 @@ public class RankEvalRequest extends ActionRequest { * Returns the specification of the ranking evaluation. */ public RankEvalSpec getRankEvalSpec() { - return rankingEvaluation; + return rankingEvaluationSpec; } /** * Set the the specification of the ranking evaluation. */ public void setRankEvalSpec(RankEvalSpec task) { - this.rankingEvaluation = task; + this.rankingEvaluationSpec = task; } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); - rankingEvaluation = new RankEvalSpec(in); + rankingEvaluationSpec = new RankEvalSpec(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - rankingEvaluation.writeTo(out); + rankingEvaluationSpec.writeTo(out); } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java index e8fe1827268..6dd3c1338fa 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java @@ -21,16 +21,24 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; /** * Returns the results for a {@link RankEvalRequest}.
@@ -121,11 +129,38 @@ public class RankEvalResponse extends ActionResponse implements ToXContentObject builder.startObject("failures"); for (String key : failures.keySet()) { builder.startObject(key); - ElasticsearchException.generateFailureXContent(builder, params, failures.get(key), false); + ElasticsearchException.generateFailureXContent(builder, params, failures.get(key), true); builder.endObject(); } builder.endObject(); builder.endObject(); return builder; } + + private static final ParseField QUALITY_LEVEL_FIELD = new ParseField("quality_level"); + private static final ParseField DETAILS_FIELD = new ParseField("details"); + private static final ParseField FAILURES_FIELD = new ParseField("failures"); + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("rank_eval_response", + true, + a -> new RankEvalResponse((Double) a[0], + ((List) a[1]).stream().collect(Collectors.toMap(EvalQueryQuality::getId, Function.identity())), + ((List>) a[2]).stream().collect(Collectors.toMap(Tuple::v1, Tuple::v2)))); + static { + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), QUALITY_LEVEL_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> EvalQueryQuality.fromXContent(p, n), + DETAILS_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, p.nextToken(), p::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, p.nextToken(), p::getTokenLocation); + Tuple tuple = new Tuple<>(n, ElasticsearchException.failureFromXContent(p)); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, p.nextToken(), p::getTokenLocation); + return tuple; + }, FAILURES_FIELD); + + } + + public static RankEvalResponse fromXContent(XContentParser parser) throws IOException { + return PARSER.apply(parser, null); + } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedSearchHit.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedSearchHit.java index 11c76f7fb30..9d8f4cc33d6 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedSearchHit.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedSearchHit.java @@ -19,11 +19,16 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser.ValueType; import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchHit; import java.io.IOException; @@ -33,7 +38,7 @@ import java.util.Optional; /** * Combines a {@link SearchHit} with a document rating. */ -public class RatedSearchHit implements Writeable, ToXContent { +public class RatedSearchHit implements Writeable, ToXContentObject { private final SearchHit searchHit; private final Optional rating; @@ -75,6 +80,23 @@ public class RatedSearchHit implements Writeable, ToXContent { return builder; } + private static final ParseField HIT_FIELD = new ParseField("hit"); + private static final ParseField RATING_FIELD = new ParseField("rating"); + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("rated_hit", true, + a -> new RatedSearchHit((SearchHit) a[0], (Optional) a[1])); + + static { + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> SearchHit.fromXContent(p), HIT_FIELD); + PARSER.declareField(ConstructingObjectParser.constructorArg(), + (p) -> p.currentToken() == XContentParser.Token.VALUE_NULL ? Optional.empty() : Optional.of(p.intValue()), RATING_FIELD, + ValueType.INT_OR_NULL); + } + + public static RatedSearchHit parse(XContentParser parser) throws IOException { + return PARSER.apply(parser, null); + } + @Override public final boolean equals(Object obj) { if (this == obj) { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RestRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RestRankEvalAction.java index 1efe5b4e39b..a2c2aeb7584 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RestRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RestRankEvalAction.java @@ -89,12 +89,14 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; */ public class RestRankEvalAction extends BaseRestHandler { + public static String ENDPOINT = "_rank_eval"; + public RestRankEvalAction(Settings settings, RestController controller) { super(settings); - controller.registerHandler(GET, "/_rank_eval", this); - controller.registerHandler(POST, "/_rank_eval", this); - controller.registerHandler(GET, "/{index}/_rank_eval", this); - controller.registerHandler(POST, "/{index}/_rank_eval", this); + controller.registerHandler(GET, "/" + ENDPOINT, this); + controller.registerHandler(POST, "/" + ENDPOINT, this); + controller.registerHandler(GET, "/{index}/" + ENDPOINT, this); + controller.registerHandler(POST, "/{index}/" + ENDPOINT, this); } @Override diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java index 26ff5f3683e..df6de75ba2c 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java @@ -19,20 +19,38 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.Index; import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; +import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; +import static org.elasticsearch.test.XContentTestUtils.insertRandomFields; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; public class EvalQueryQualityTests extends ESTestCase { private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(new RankEvalPlugin().getNamedWriteables()); + @SuppressWarnings("resource") + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new RankEvalPlugin().getNamedXContent()); + } + public static EvalQueryQuality randomEvalQueryQuality() { List unknownDocs = new ArrayList<>(); int numberOfUnknownDocs = randomInt(5); @@ -42,7 +60,10 @@ public class EvalQueryQualityTests extends ESTestCase { int numberOfSearchHits = randomInt(5); List ratedHits = new ArrayList<>(); for (int i = 0; i < numberOfSearchHits; i++) { - ratedHits.add(RatedSearchHitTests.randomRatedSearchHit()); + RatedSearchHit ratedSearchHit = RatedSearchHitTests.randomRatedSearchHit(); + // we need to associate each hit with an index name otherwise rendering will not work + ratedSearchHit.getSearchHit().shard(new SearchShardTarget("_na_", new Index("index", "_na_"), 0, null)); + ratedHits.add(ratedSearchHit); } EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, true)); @@ -65,6 +86,35 @@ public class EvalQueryQualityTests extends ESTestCase { assertNotSame(deserialized, original); } + public void testXContentParsing() throws IOException { + EvalQueryQuality testItem = randomEvalQueryQuality(); + boolean humanReadable = randomBoolean(); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, humanReadable); + // skip inserting random fields for: + // - the root object, since we expect a particular queryId there in this test + // - the `metric_details` section, which can potentially contain different namedXContent names + // - everything under `hits` (we test lenient SearchHit parsing elsewhere) + Predicate pathsToExclude = path -> path.isEmpty() || path.endsWith("metric_details") || path.contains("hits"); + BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, pathsToExclude, random()); + EvalQueryQuality parsedItem; + try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + String queryId = parser.currentName(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation); + parsedItem = EvalQueryQuality.fromXContent(parser, queryId); + ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.currentToken(), parser::getTokenLocation); + ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + assertNull(parser.nextToken()); + } + assertNotSame(testItem, parsedItem); + // we cannot check equality of object here because some information (e.g. SearchHit#shard) cannot fully be + // parsed back after going through the rest layer. That's why we only check that the original and the parsed item + // have the same xContent representation + assertToXContentEquivalent(originalBytes, toXContent(parsedItem, xContentType, humanReadable), xContentType); + } + private static EvalQueryQuality copy(EvalQueryQuality original) throws IOException { return ESTestCase.copyWriteable(original, namedWritableRegistry, EvalQueryQuality::new); } @@ -77,7 +127,7 @@ public class EvalQueryQualityTests extends ESTestCase { String id = original.getId(); double qualityLevel = original.getQualityLevel(); List ratedHits = new ArrayList<>(original.getHitsAndRatings()); - MetricDetails metricDetails = original.getMetricDetails(); + MetricDetail metricDetails = original.getMetricDetails(); switch (randomIntBetween(0, 3)) { case 0: id = id + "_"; diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java index 881b9e04709..26492d3566f 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java @@ -19,7 +19,13 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.text.Text; @@ -27,11 +33,15 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentLocation; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.index.Index; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchParseException; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TestSearchContext; import java.io.IOException; import java.util.ArrayList; @@ -41,9 +51,27 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Predicate; + +import static java.util.Collections.singleton; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.XContentTestUtils.insertRandomFields; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.hamcrest.Matchers.instanceOf; public class RankEvalResponseTests extends ESTestCase { + private static final Exception[] RANDOM_EXCEPTIONS = new Exception[] { + new ClusterBlockException(singleton(DiscoverySettings.NO_MASTER_BLOCK_WRITES)), + new CircuitBreakingException("Data too large", 123, 456), + new SearchParseException(new TestSearchContext(null), "Parse failure", new XContentLocation(12, 98)), + new IllegalArgumentException("Closed resource", new RuntimeException("Resource")), + new SearchPhaseExecutionException("search", "all shards failed", + new ShardSearchFailure[] { new ShardSearchFailure(new ParsingException(1, 2, "foobar", null), + new SearchShardTarget("node_1", new Index("foo", "_na_"), 1, null)) }), + new ElasticsearchException("Parsing failed", + new ParsingException(9, 42, "Wrong state", new NullPointerException("Unexpected null value"))) }; + private static RankEvalResponse createRandomResponse() { int numberOfRequests = randomIntBetween(0, 5); Map partials = new HashMap<>(numberOfRequests); @@ -62,8 +90,7 @@ public class RankEvalResponseTests extends ESTestCase { int numberOfErrors = randomIntBetween(0, 2); Map errors = new HashMap<>(numberOfRequests); for (int i = 0; i < numberOfErrors; i++) { - errors.put(randomAlphaOfLengthBetween(3, 10), - new IllegalArgumentException(randomAlphaOfLength(10))); + errors.put(randomAlphaOfLengthBetween(3, 10), randomFrom(RANDOM_EXCEPTIONS)); } return new RankEvalResponse(randomDouble(), partials, errors); } @@ -84,6 +111,41 @@ public class RankEvalResponseTests extends ESTestCase { } } + public void testXContentParsing() throws IOException { + RankEvalResponse testItem = createRandomResponse(); + boolean humanReadable = randomBoolean(); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, humanReadable); + // skip inserting random fields for: + // - the `details` section, which can contain arbitrary queryIds + // - everything under `failures` (exceptions parsing is quiet lenient) + // - everything under `hits` (we test lenient SearchHit parsing elsewhere) + Predicate pathsToExclude = path -> (path.endsWith("details") || path.contains("failures") || path.contains("hits")); + BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, pathsToExclude, random()); + RankEvalResponse parsedItem; + try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) { + parsedItem = RankEvalResponse.fromXContent(parser); + assertNull(parser.nextToken()); + } + assertNotSame(testItem, parsedItem); + // We cannot check equality of object here because some information (e.g. + // SearchHit#shard) cannot fully be parsed back. + assertEquals(testItem.getEvaluationResult(), parsedItem.getEvaluationResult(), 0.0); + assertEquals(testItem.getPartialResults().keySet(), parsedItem.getPartialResults().keySet()); + for (EvalQueryQuality metricDetail : testItem.getPartialResults().values()) { + EvalQueryQuality parsedEvalQueryQuality = parsedItem.getPartialResults().get(metricDetail.getId()); + assertToXContentEquivalent(toXContent(metricDetail, xContentType, humanReadable), + toXContent(parsedEvalQueryQuality, xContentType, humanReadable), xContentType); + } + // Also exceptions that are parsed back will be different since they are re-wrapped during parsing. + // However, we can check that there is the expected number + assertEquals(testItem.getFailures().keySet(), parsedItem.getFailures().keySet()); + for (String queryId : testItem.getFailures().keySet()) { + Exception ex = parsedItem.getFailures().get(queryId); + assertThat(ex, instanceOf(ElasticsearchException.class)); + } + } + public void testToXContent() throws IOException { EvalQueryQuality coffeeQueryQuality = new EvalQueryQuality("coffee_query", 0.1); coffeeQueryQuality.addHitsAndRatings(Arrays.asList(searchHit("index", 123, 5), searchHit("index", 456, null))); @@ -106,7 +168,11 @@ public class RankEvalResponseTests extends ESTestCase { " }," + " \"failures\": {" + " \"beer_query\": {" + - " \"error\": \"ParsingException[someMsg]\"" + + " \"error\" : {\"root_cause\": [{\"type\":\"parsing_exception\", \"reason\":\"someMsg\",\"line\":0,\"col\":0}]," + + " \"type\":\"parsing_exception\"," + + " \"reason\":\"someMsg\"," + + " \"line\":0,\"col\":0" + + " }" + " }" + " }" + "}").replaceAll("\\s+", ""), xContent); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java index cf66b0b7797..622c49a9886 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java @@ -19,8 +19,12 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; @@ -29,6 +33,7 @@ import java.util.Collections; import java.util.Optional; import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; +import static org.elasticsearch.test.XContentTestUtils.insertRandomFields; public class RatedSearchHitTests extends ESTestCase { @@ -66,6 +71,19 @@ public class RatedSearchHitTests extends ESTestCase { assertNotSame(deserialized, original); } + public void testXContentRoundtrip() throws IOException { + RatedSearchHit testItem = randomRatedSearchHit(); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean()); + BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random()); + try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) { + RatedSearchHit parsedItem = RatedSearchHit.parse(parser); + assertNotSame(testItem, parsedItem); + assertEquals(testItem, parsedItem); + assertEquals(testItem.hashCode(), parsedItem.hashCode()); + } + } + public void testEqualsAndHash() throws IOException { checkEqualsAndHashCode(randomRatedSearchHit(), RatedSearchHitTests::copy, RatedSearchHitTests::mutateTestItem); } diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml index 4a244dcb9e5..3481d17dceb 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml @@ -67,7 +67,7 @@ - match: { quality_level: 1} - match: { details.amsterdam_query.quality_level: 1.0} - match: { details.amsterdam_query.unknown_docs: [ {"_index": "foo", "_id": "doc4"}]} - - match: { details.amsterdam_query.metric_details: {"relevant_docs_retrieved": 2, "docs_retrieved": 2}} + - match: { details.amsterdam_query.metric_details.precision: {"relevant_docs_retrieved": 2, "docs_retrieved": 2}} - length: { details.amsterdam_query.hits: 3} - match: { details.amsterdam_query.hits.0.hit._id: "doc2"} @@ -79,7 +79,7 @@ - match: { details.berlin_query.quality_level: 1.0} - match: { details.berlin_query.unknown_docs: [ {"_index": "foo", "_id": "doc4"}]} - - match: { details.berlin_query.metric_details: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}} + - match: { details.berlin_query.metric_details.precision: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}} - length: { details.berlin_query.hits: 2} - match: { details.berlin_query.hits.0.hit._id: "doc1" } - match: { details.berlin_query.hits.0.rating: 1} @@ -156,10 +156,10 @@ - lt: {quality_level: 0.417} - gt: {details.amsterdam_query.quality_level: 0.333} - lt: {details.amsterdam_query.quality_level: 0.334} - - match: {details.amsterdam_query.metric_details: {"first_relevant": 3}} + - match: {details.amsterdam_query.metric_details.mean_reciprocal_rank: {"first_relevant": 3}} - match: {details.amsterdam_query.unknown_docs: [ {"_index": "foo", "_id": "doc2"}, {"_index": "foo", "_id": "doc3"} ]} - match: {details.berlin_query.quality_level: 0.5} - - match: {details.berlin_query.metric_details: {"first_relevant": 2}} + - match: {details.berlin_query.metric_details.mean_reciprocal_rank: {"first_relevant": 2}} - match: {details.berlin_query.unknown_docs: [ {"_index": "foo", "_id": "doc1"}]} diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml index 24902253eb0..9cb9dbbbcf1 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml @@ -37,6 +37,6 @@ - match: { quality_level: 1} - match: { details.amsterdam_query.quality_level: 1.0} - match: { details.amsterdam_query.unknown_docs: [ ]} - - match: { details.amsterdam_query.metric_details: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}} + - match: { details.amsterdam_query.metric_details.precision: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}} - is_true: failures.invalid_query diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ObjectParser.java b/server/src/main/java/org/elasticsearch/common/xcontent/ObjectParser.java index 8ba30178dc9..0b15baf1337 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/ObjectParser.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ObjectParser.java @@ -298,6 +298,7 @@ public final class ObjectParser extends AbstractObjectParser extends AbstractObjectParser extends AbstractObjectParser fields = get(Fields.FIELDS, values, null); + Map fields = get(Fields.FIELDS, values, Collections.emptyMap()); SearchHit searchHit = new SearchHit(-1, id, type, nestedIdentity, fields); searchHit.index = get(Fields._INDEX, values, null); @@ -562,7 +563,6 @@ public final class SearchHit implements Streamable, ToXContentObject, Iterable