From e278c1d17dd6276fe2fc65495691077f6826ab4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Tue, 14 Nov 2017 19:26:32 +0100 Subject: [PATCH] Improving and cleaning up tests Removing the unnecessary RankEvalTestHelper, making use of the common test infra in ESTestCase, also hardening a few of the classes by making more fields final. --- .../rankeval/DiscountedCumulativeGain.java | 48 +++----- .../index/rankeval/DocumentKey.java | 106 ---------------- .../index/rankeval/EvalQueryQuality.java | 8 +- ...alityMetric.java => EvaluationMetric.java} | 23 ++-- .../index/rankeval/MeanReciprocalRank.java | 56 +++++---- .../index/rankeval/PrecisionAtK.java | 69 +++++------ .../index/rankeval/RankEvalPlugin.java | 21 ++-- .../index/rankeval/RankEvalSpec.java | 24 ++-- .../index/rankeval/RatedDocument.java | 101 +++++++++++---- .../index/rankeval/RatedRequest.java | 1 + .../rankeval/TransportRankEvalAction.java | 4 +- .../DiscountedCumulativeGainTests.java | 115 ++++++++++-------- .../index/rankeval/DocumentKeyTests.java | 67 ---------- .../index/rankeval/EvalQueryQualityTests.java | 19 +-- ...ests.java => MeanReciprocalRankTests.java} | 60 +++++---- ...isionTests.java => PrecisionAtKTests.java} | 90 +++++++------- .../index/rankeval/RankEvalRequestIT.java | 12 +- .../index/rankeval/RankEvalResponseTests.java | 12 +- .../index/rankeval/RankEvalSpecTests.java | 51 ++++---- .../index/rankeval/RankEvalTestHelper.java | 94 -------------- .../index/rankeval/RatedDocumentTests.java | 25 ++-- .../index/rankeval/RatedRequestsTests.java | 40 +++--- .../index/rankeval/RatedSearchHitTests.java | 14 ++- .../index/rankeval/TestRatingEnum.java | 24 ++++ .../test/rank_eval/30_failures.yml | 2 +- 25 files changed, 450 insertions(+), 636 deletions(-) delete mode 100644 modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DocumentKey.java rename modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/{RankedListQualityMetric.java => EvaluationMetric.java} (83%) delete mode 100644 modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DocumentKeyTests.java rename modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/{ReciprocalRankTests.java => MeanReciprocalRankTests.java} (79%) rename modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/{PrecisionTests.java => PrecisionAtKTests.java} (74%) delete mode 100644 modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalTestHelper.java create mode 100644 modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/TestRatingEnum.java diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java index 6f843f92461..a544ffcb4ea 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java @@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchHit; @@ -35,22 +35,25 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; -import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; -public class DiscountedCumulativeGain implements RankedListQualityMetric { +public class DiscountedCumulativeGain implements EvaluationMetric { /** If set to true, the dcg will be normalized (ndcg) */ - private boolean normalize; + private final boolean normalize; + /** * If set to, this will be the rating for docs the user hasn't supplied an * explicit rating for */ - private Integer unknownDocRating; + private final Integer unknownDocRating; public static final String NAME = "dcg"; private static final double LOG2 = Math.log(2.0); public DiscountedCumulativeGain() { + this(false, null); } /** @@ -82,13 +85,6 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric { return NAME; } - /** - * If set to true, the dcg will be normalized (ndcg) - */ - public void setNormalize(boolean normalize) { - this.normalize = normalize; - } - /** * check whether this metric computes only dcg or "normalized" ndcg */ @@ -96,13 +92,6 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric { return this.normalize; } - /** - * the rating for docs the user hasn't supplied an explicit rating for - */ - public void setUnknownDocRating(int unknownDocRating) { - this.unknownDocRating = unknownDocRating; - } - /** * get the rating used for unrated documents */ @@ -118,10 +107,10 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric { List ratedHits = joinHitsWithRatings(hits, ratedDocs); List ratingsInSearchHits = new ArrayList<>(ratedHits.size()); for (RatedSearchHit hit : ratedHits) { - // unknownDocRating might be null, which means it will be unrated - // docs are ignored in the dcg calculation - // we still need to add them as a placeholder so the rank of the - // subsequent ratings is correct + // unknownDocRating might be null, which means it will be unrated docs are + // ignored in the dcg calculation + // we still need to add them as a placeholder so the rank of the subsequent + // ratings is correct ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating)); } double dcg = computeDCG(ratingsInSearchHits); @@ -151,12 +140,15 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric { private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); - private static final ObjectParser PARSER = new ObjectParser<>( - "dcg_at", () -> new DiscountedCumulativeGain()); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at", + args -> { + Boolean normalized = (Boolean) args[0]; + return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]); + }); static { - PARSER.declareBoolean(DiscountedCumulativeGain::setNormalize, NORMALIZE_FIELD); - PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD); + PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD); + PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD); } public static DiscountedCumulativeGain fromXContent(XContentParser parser) { @@ -193,6 +185,4 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric { public final int hashCode() { return Objects.hash(normalize, unknownDocRating); } - - // TODO maybe also add debugging breakdown here } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DocumentKey.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DocumentKey.java deleted file mode 100644 index 7bbda60f18c..00000000000 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DocumentKey.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.index.rankeval; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ToXContentObject; -import org.elasticsearch.common.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Objects; - -public class DocumentKey implements Writeable, ToXContentObject { - - private String docId; - private String index; - - void setIndex(String index) { - this.index = index; - } - - void setDocId(String docId) { - this.docId = docId; - } - - public DocumentKey(String index, String docId) { - if (Strings.isNullOrEmpty(index)) { - throw new IllegalArgumentException("Index must be set for each rated document"); - } - if (Strings.isNullOrEmpty(docId)) { - throw new IllegalArgumentException("DocId must be set for each rated document"); - } - - this.index = index; - this.docId = docId; - } - - public DocumentKey(StreamInput in) throws IOException { - this.index = in.readString(); - this.docId = in.readString(); - } - - public String getIndex() { - return index; - } - - public String getDocID() { - return docId; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(index); - out.writeString(docId); - } - - @Override - public final boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - DocumentKey other = (DocumentKey) obj; - return Objects.equals(index, other.index) && Objects.equals(docId, other.docId); - } - - @Override - public final int hashCode() { - return Objects.hash(index, docId); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), index); - builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), docId); - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } -} diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java index 6cd7cbac52a..e7e00c1699d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvalQueryQuality.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import java.io.IOException; import java.util.ArrayList; @@ -91,8 +92,11 @@ public class EvalQueryQuality implements ToXContent, Writeable { builder.startObject(id); builder.field("quality_level", this.qualityLevel); builder.startArray("unknown_docs"); - for (DocumentKey key : RankedListQualityMetric.filterUnknownDocuments(hits)) { - key.toXContent(builder, params); + for (DocumentKey key : EvaluationMetric.filterUnknownDocuments(hits)) { + builder.startObject(); + builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), key.getIndex()); + builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), key.getDocId()); + builder.endObject(); } builder.endArray(); builder.startArray("hits"); diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvaluationMetric.java similarity index 83% rename from modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java rename to modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvaluationMetric.java index 35fa49fa7fa..6754d039266 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/EvaluationMetric.java @@ -24,7 +24,9 @@ import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser.Token; +import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import java.io.IOException; import java.util.ArrayList; @@ -35,13 +37,14 @@ import java.util.Optional; import java.util.stream.Collectors; /** - * Classes implementing this interface provide a means to compute the quality of a result list returned by some search. + * Implementations of {@link EvaluationMetric} need to provide a way to compute the quality metric for + * a result list returned by some search (@link {@link SearchHits}) and a list of rated documents. */ -public interface RankedListQualityMetric extends ToXContent, NamedWriteable { +public interface EvaluationMetric extends ToXContent, NamedWriteable { /** * Returns a single metric representing the ranking quality of a set of returned - * documents wrt. to a set of document Ids labeled as relevant for this search. + * documents wrt. to a set of document ids labeled as relevant for this search. * * @param taskId * the id of the query for which the ranking is currently evaluated @@ -55,15 +58,15 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable { */ EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List ratedDocs); - static RankedListQualityMetric fromXContent(XContentParser parser) throws IOException { - RankedListQualityMetric rc; + static EvaluationMetric fromXContent(XContentParser parser) throws IOException { + EvaluationMetric rc; Token token = parser.nextToken(); if (token != XContentParser.Token.FIELD_NAME) { throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name"); } String metricName = parser.currentName(); - // TODO maybe switch to using a plugable registry later? + // TODO switch to using a plugable registry switch (metricName) { case PrecisionAtK.NAME: rc = PrecisionAtK.fromXContent(parser); @@ -101,13 +104,19 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable { return ratedSearchHits; } + /** + * filter @link {@link RatedSearchHit} that don't have a rating + */ static List filterUnknownDocuments(List ratedHits) { - // join hits with rated documents List unknownDocs = ratedHits.stream().filter(hit -> hit.getRating().isPresent() == false) .map(hit -> new DocumentKey(hit.getSearchHit().getIndex(), hit.getSearchHit().getId())).collect(Collectors.toList()); return unknownDocs; } + /** + * how evaluation metrics for particular search queries get combined for the overall evaluation score. + * Defaults to averaging over the partial results. + */ default double combine(Collection partialResults) { return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size(); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java index c8a84a03b39..2b47f68587e 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java @@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchHit; @@ -34,49 +34,45 @@ import java.util.Optional; import javax.naming.directory.SearchResult; -import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; /** - * Evaluate mean reciprocal rank. By default documents with a rating equal or bigger - * than 1 are considered to be "relevant" for the reciprocal rank calculation. - * This value can be changes using the "relevant_rating_threshold" parameter. + * Evaluates using mean reciprocal rank. By default documents with a rating + * equal or bigger than 1 are considered to be "relevant" for the reciprocal + * rank calculation. This value can be changes using the + * "relevant_rating_threshold" parameter. */ -public class MeanReciprocalRank implements RankedListQualityMetric { +public class MeanReciprocalRank implements EvaluationMetric { + + private static final int DEFAULT_RATING_THRESHOLD = 1; public static final String NAME = "mean_reciprocal_rank"; /** ratings equal or above this value will be considered relevant. */ - private int relevantRatingThreshhold = 1; + private final int relevantRatingThreshhold; - /** - * Initializes maxAcceptableRank with 10 - */ public MeanReciprocalRank() { - // use defaults + this(DEFAULT_RATING_THRESHOLD); } public MeanReciprocalRank(StreamInput in) throws IOException { this.relevantRatingThreshhold = in.readVInt(); } - @Override - public String getWriteableName() { - return NAME; - } - - /** - * Sets the rating threshold above which ratings are considered to be - * "relevant" for this metric. - */ - public void setRelevantRatingThreshhold(int threshold) { + public MeanReciprocalRank(int threshold) { if (threshold < 0) { throw new IllegalArgumentException( "Relevant rating threshold for precision must be positive integer."); } - this.relevantRatingThreshhold = threshold; } + @Override + public String getWriteableName() { + return NAME; + } + /** * Return the rating threshold above which ratings are considered to be * "relevant" for this metric. Defaults to 1. @@ -119,13 +115,19 @@ public class MeanReciprocalRank implements RankedListQualityMetric { out.writeVInt(relevantRatingThreshhold); } - private static final ParseField RELEVANT_RATING_FIELD = new ParseField( - "relevant_rating_threshold"); - private static final ObjectParser PARSER = new ObjectParser<>( - "reciprocal_rank", () -> new MeanReciprocalRank()); + private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("reciprocal_rank", + args -> { + Integer optionalThreshold = (Integer) args[0]; + if (optionalThreshold == null) { + return new MeanReciprocalRank(); + } else { + return new MeanReciprocalRank(optionalThreshold); + } + }); static { - PARSER.declareInt(MeanReciprocalRank::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD); + PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD); } public static MeanReciprocalRank fromXContent(XContentParser parser) { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java index ea31573f8b3..76cc72bc82d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java @@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchHit; @@ -34,7 +34,8 @@ import java.util.Optional; import javax.naming.directory.SearchResult; -import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; /** * Evaluate Precision of the search results. Documents without a rating are @@ -42,15 +43,12 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW * considered to be "relevant" for the precision calculation. This value can be * changes using the "relevant_rating_threshold" parameter. */ -public class PrecisionAtK implements RankedListQualityMetric { +public class PrecisionAtK implements EvaluationMetric { public static final String NAME = "precision"; - private static final ParseField RELEVANT_RATING_FIELD = new ParseField( - "relevant_rating_threshold"); + private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled"); - private static final ObjectParser PARSER = new ObjectParser<>(NAME, - PrecisionAtK::new); /** * This setting controls how unlabeled documents in the search hits are @@ -58,29 +56,47 @@ public class PrecisionAtK implements RankedListQualityMetric { * as true or false positives. Set to 'false', they are treated as false * positives. */ - private boolean ignoreUnlabeled = false; + private final boolean ignoreUnlabeled; /** ratings equal or above this value will be considered relevant. */ - private int relevantRatingThreshhold = 1; + private final int relevantRatingThreshhold; + + public PrecisionAtK(int threshold, boolean ignoreUnlabeled) { + if (threshold < 0) { + throw new IllegalArgumentException( + "Relevant rating threshold for precision must be positive integer."); + } + this.relevantRatingThreshhold = threshold; + this.ignoreUnlabeled = ignoreUnlabeled; + } public PrecisionAtK() { - // needed for supplier in parser + this(1, false); } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + args -> { + Integer threshHold = (Integer) args[0]; + Boolean ignoreUnlabeled = (Boolean) args[1]; + return new PrecisionAtK(threshHold == null ? 1 : threshHold, + ignoreUnlabeled == null ? false : ignoreUnlabeled); + }); + static { - PARSER.declareInt(PrecisionAtK::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD); - PARSER.declareBoolean(PrecisionAtK::setIgnoreUnlabeled, IGNORE_UNLABELED_FIELD); + PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD); + PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD); } - public PrecisionAtK(StreamInput in) throws IOException { - relevantRatingThreshhold = in.readOptionalVInt(); - ignoreUnlabeled = in.readOptionalBoolean(); + PrecisionAtK(StreamInput in) throws IOException { + relevantRatingThreshhold = in.readVInt(); + ignoreUnlabeled = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalVInt(relevantRatingThreshhold); - out.writeOptionalBoolean(ignoreUnlabeled); + out.writeVInt(relevantRatingThreshhold); + out.writeBoolean(ignoreUnlabeled); } @Override @@ -88,18 +104,6 @@ public class PrecisionAtK implements RankedListQualityMetric { return NAME; } - /** - * Sets the rating threshold above which ratings are considered to be - * "relevant" for this metric. - */ - public void setRelevantRatingThreshhold(int threshold) { - if (threshold < 0) { - throw new IllegalArgumentException( - "Relevant rating threshold for precision must be positive integer."); - } - this.relevantRatingThreshhold = threshold; - } - /** * Return the rating threshold above which ratings are considered to be * "relevant" for this metric. Defaults to 1. @@ -108,13 +112,6 @@ public class PrecisionAtK implements RankedListQualityMetric { return relevantRatingThreshhold; } - /** - * Sets the 'ìgnore_unlabeled' parameter - */ - public void setIgnoreUnlabeled(boolean ignoreUnlabeled) { - this.ignoreUnlabeled = ignoreUnlabeled; - } - /** * Gets the 'ìgnore_unlabeled' parameter */ diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 69825405363..189f7ab91ac 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -56,24 +56,19 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin { } /** - * Returns parsers for {@link NamedWriteable} this plugin will use over the - * transport protocol. - * + * Returns parsers for {@link NamedWriteable} objects that this plugin sends over the transport protocol. * @see NamedWriteableRegistry */ @Override public List getNamedWriteables() { List namedWriteables = new ArrayList<>(); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, - PrecisionAtK.NAME, PrecisionAtK::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, - MeanReciprocalRank.NAME, MeanReciprocalRank::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, - DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME, - PrecisionAtK.Breakdown::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, - MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME, PrecisionAtK.Breakdown::new)); + namedWriteables + .add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new)); return namedWriteables; } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java index 60d0a5dd4fd..3a88bc4de57 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java @@ -34,7 +34,9 @@ import org.elasticsearch.script.Script; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Objects; @@ -53,9 +55,9 @@ public class RankEvalSpec implements Writeable, ToXContentObject { * Collection of query specifications, that is e.g. search request templates * to use for query translation. */ - private Collection ratedRequests = new ArrayList<>(); + private final List ratedRequests; /** Definition of the quality metric, e.g. precision at N */ - private RankedListQualityMetric metric; + private final EvaluationMetric metric; /** Maximum number of requests to execute in parallel. */ private int maxConcurrentSearches = MAX_CONCURRENT_SEARCHES; /** Default max number of requests. */ @@ -63,7 +65,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject { /** optional: Templates to base test requests on */ private Map templates = new HashMap<>(); - public RankEvalSpec(Collection ratedRequests, RankedListQualityMetric metric, + public RankEvalSpec(List ratedRequests, EvaluationMetric metric, Collection templates) { if (ratedRequests == null || ratedRequests.size() < 1) { throw new IllegalStateException( @@ -92,7 +94,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject { } } - public RankEvalSpec(Collection ratedRequests, RankedListQualityMetric metric) { + public RankEvalSpec(List ratedRequests, EvaluationMetric metric) { this(ratedRequests, metric, null); } @@ -102,7 +104,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject { for (int i = 0; i < specSize; i++) { ratedRequests.add(new RatedRequest(in)); } - metric = in.readNamedWriteable(RankedListQualityMetric.class); + metric = in.readNamedWriteable(EvaluationMetric.class); int size = in.readVInt(); for (int i = 0; i < size; i++) { String key = in.readString(); @@ -128,13 +130,13 @@ public class RankEvalSpec implements Writeable, ToXContentObject { } /** Returns the metric to use for quality evaluation.*/ - public RankedListQualityMetric getMetric() { + public EvaluationMetric getMetric() { return metric; } /** Returns a list of intent to query translation specifications to evaluate. */ - public Collection getRatedRequests() { - return ratedRequests; + public List getRatedRequests() { + return Collections.unmodifiableList(ratedRequests); } /** Returns the template to base test requests on. */ @@ -160,8 +162,8 @@ public class RankEvalSpec implements Writeable, ToXContentObject { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("rank_eval", - a -> new RankEvalSpec((Collection) a[0], - (RankedListQualityMetric) a[1], (Collection) a[2])); + a -> new RankEvalSpec((List) a[0], + (EvaluationMetric) a[1], (Collection) a[2])); static { PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> { @@ -169,7 +171,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject { } , REQUESTS_FIELD); PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { try { - return RankedListQualityMetric.fromXContent(p); + return EvaluationMetric.fromXContent(p); } catch (IOException ex) { throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java index ec19aa8aa0e..fc5880e60cd 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java @@ -33,13 +33,24 @@ import java.io.IOException; import java.util.Objects; /** - * A document ID and its rating for the query QA use case. + * Represents a document (specified by its _index/_id) and its corresponding rating + * with respect to a specific search query. + *

+ * Json structure in a request: + *

+ * {
+ *   "_index": "my_index",
+ *   "_id": "doc1",
+ *   "rating": 0
+ * }
+ * 
+ * */ public class RatedDocument implements Writeable, ToXContentObject { - public static final ParseField RATING_FIELD = new ParseField("rating"); - public static final ParseField DOC_ID_FIELD = new ParseField("_id"); - public static final ParseField INDEX_FIELD = new ParseField("_index"); + static final ParseField RATING_FIELD = new ParseField("rating"); + static final ParseField DOC_ID_FIELD = new ParseField("_id"); + static final ParseField INDEX_FIELD = new ParseField("_index"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("rated_document", a -> new RatedDocument((String) a[0], (String) a[1], (Integer) a[2])); @@ -50,23 +61,19 @@ public class RatedDocument implements Writeable, ToXContentObject { PARSER.declareInt(ConstructingObjectParser.constructorArg(), RATING_FIELD); } - private int rating; - private DocumentKey key; + private final int rating; + private final DocumentKey key; - public RatedDocument(String index, String docId, int rating) { - this(new DocumentKey(index, docId), rating); - } - - public RatedDocument(StreamInput in) throws IOException { - this.key = new DocumentKey(in); - this.rating = in.readVInt(); - } - - public RatedDocument(DocumentKey ratedDocumentKey, int rating) { - this.key = ratedDocumentKey; + public RatedDocument(String index, String id, int rating) { + this.key = new DocumentKey(index, id); this.rating = rating; } + RatedDocument(StreamInput in) throws IOException { + this.key = new DocumentKey(in.readString(), in.readString()); + this.rating = in.readVInt(); + } + public DocumentKey getKey() { return this.key; } @@ -76,7 +83,7 @@ public class RatedDocument implements Writeable, ToXContentObject { } public String getDocID() { - return key.getDocID(); + return key.getDocId(); } public int getRating() { @@ -85,11 +92,12 @@ public class RatedDocument implements Writeable, ToXContentObject { @Override public void writeTo(StreamOutput out) throws IOException { - this.key.writeTo(out); + out.writeString(key.getIndex()); + out.writeString(key.getDocId()); out.writeVInt(rating); } - public static RatedDocument fromXContent(XContentParser parser) { + static RatedDocument fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } @@ -97,7 +105,7 @@ public class RatedDocument implements Writeable, ToXContentObject { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(INDEX_FIELD.getPreferredName(), key.getIndex()); - builder.field(DOC_ID_FIELD.getPreferredName(), key.getDocID()); + builder.field(DOC_ID_FIELD.getPreferredName(), key.getDocId()); builder.field(RATING_FIELD.getPreferredName(), rating); builder.endObject(); return builder; @@ -124,4 +132,55 @@ public class RatedDocument implements Writeable, ToXContentObject { public final int hashCode() { return Objects.hash(key, rating); } + + /** + * a joint document key consisting of the documents index and id + */ + static class DocumentKey { + + private final String docId; + private final String index; + + DocumentKey(String index, String docId) { + if (Strings.isNullOrEmpty(index)) { + throw new IllegalArgumentException("Index must be set for each rated document"); + } + if (Strings.isNullOrEmpty(docId)) { + throw new IllegalArgumentException("DocId must be set for each rated document"); + } + + this.index = index; + this.docId = docId; + } + + String getIndex() { + return index; + } + + String getDocId() { + return docId; + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DocumentKey other = (DocumentKey) obj; + return Objects.equals(index, other.index) && Objects.equals(docId, other.docId); + } + + @Override + public final int hashCode() { + return Objects.hash(index, docId); + } + + @Override + public String toString() { + return "{\"_index\":\"" + index + "\",\"_id\":\"" + docId + "\"}"; + } + } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java index af322add5ec..e8ed925987f 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import org.elasticsearch.search.builder.SearchSourceBuilder; import java.io.IOException; diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index 8c9024d0500..4e3fa8afa46 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -157,11 +157,11 @@ public class TransportRankEvalAction private RatedRequest specification; private Map requestDetails; private Map errors; - private RankedListQualityMetric metric; + private EvaluationMetric metric; private AtomicInteger responseCounter; public RankEvalActionListener(ActionListener listener, - RankedListQualityMetric metric, RatedRequest specification, + EvaluationMetric metric, RatedRequest specification, Map details, Map errors, AtomicInteger responseCounter) { this.listener = listener; diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java index 4154000a7b0..b6bd80f497e 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -36,20 +37,23 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments; +import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments; +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; + public class DiscountedCumulativeGainTests extends ESTestCase { /** * Assuming the docs are ranked in the following order: * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / - * log_2(rank + 1) + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * ------------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 - * 3 | 3 | 7.0 | 2.0 | 3.5 4 | 0 | 0.0 | 2.321928094887362 | 0.0 5 | 1 | 1.0 - * | 2.584962500721156 | 0.38685280723454163 6 | 2 | 3.0 | 2.807354922057604 - * | 1.0686215613240666 + * 1 | 3 | 7.0 | 1.0 | 7.0 2 |  + * 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 3 | 3 | 7.0 | 2.0 | 3.5 + * 4 | 0 | 0.0 | 2.321928094887362 | 0.0 + * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 + * 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666 * * dcg = 13.84826362927298 (sum of last column) */ @@ -69,17 +73,18 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * Check with normalization: to get the maximal possible dcg, sort documents by * relevance in descending order * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / - * log_2(rank + 1) + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * --------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0  | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 - * 3 | 2 | 3.0 | 2.0  | 1.5 4 | 2 | 3.0 | 2.321928094887362  - * | 1.2920296742201793 5 | 1 | 1.0 | 2.584962500721156  | 0.38685280723454163 6 - * | 0 | 0.0 | 2.807354922057604  | 0.0 + * 1 | 3 | 7.0 | 1.0  | 7.0 + * 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 + * 3 | 2 | 3.0 | 2.0  | 1.5 + * 4 | 2 | 3.0 | 2.321928094887362 | 1.2920296742201793 + * 5 | 1 | 1.0 | 2.584962500721156  | 0.38685280723454163 + * 6 | 0 | 0.0 | 2.807354922057604  | 0.0 * * idcg = 14.595390756454922 (sum of last column) */ - dcg.setNormalize(true); + dcg = new DiscountedCumulativeGain(true, null); assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); } @@ -87,12 +92,14 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * This tests metric when some documents in the search result don't have a * rating provided by the user. * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / - * log_2(rank + 1) + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * ------------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 - * 3 | 3 | 7.0 | 2.0 | 3.5 4 | n/a | n/a | n/a | n/a 5 | 1 | 1.0 - * | 2.584962500721156 | 0.38685280723454163 6 | n/a | n/a | n/a | n/a + * 1 | 3 | 7.0 | 1.0 | 7.0 2 |  + * 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 3 | 3 | 7.0 | 2.0 | 3.5 + * 4 | n/a | n/a | n/a | n/a + * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 + * 6 | n/a | n/a | n/a | n/a * * dcg = 12.779642067948913 (sum of last column) */ @@ -118,16 +125,18 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * Check with normalization: to get the maximal possible dcg, sort documents by * relevance in descending order * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / - * log_2(rank + 1) + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * ---------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0  | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 - * 3 | 2 | 3.0 | 2.0  | 1.5 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339 - * 5 | n.a | n.a | n.a.  | n.a. 6 | n.a | n.a | n.a  | n.a + * 1 | 3 | 7.0 | 1.0  | 7.0 + * 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 + * 3 | 2 | 3.0 | 2.0  | 1.5 + * 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339 + * 5 | n.a | n.a | n.a.  | n.a. + * 6 | n.a | n.a | n.a  | n.a * * idcg = 13.347184833073591 (sum of last column) */ - dcg.setNormalize(true); + dcg = new DiscountedCumulativeGain(true, null); assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); } @@ -136,13 +145,15 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * documents than search hits because we restrict DCG to be calculated at the * fourth position * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / - * log_2(rank + 1) + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * ------------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 - * 3 | 3 | 7.0 | 2.0 | 3.5 4 | n/a | n/a | n/a | n/a - * ----------------------------------------------------------------- 5 | 1 | 1.0 - * | 2.584962500721156 | 0.38685280723454163 6 | n/a | n/a | n/a | n/a + * 1 | 3 | 7.0 | 1.0 | 7.0 2 |  + * 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 3 | 3 | 7.0 | 2.0 | 3.5 + * 4 | n/a | n/a | n/a | n/a + * ----------------------------------------------------------------- + * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 + * 6 | n/a | n/a | n/a | n/a * * dcg = 12.392789260714371 (sum of last column until position 4) */ @@ -171,22 +182,24 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * Check with normalization: to get the maximal possible dcg, sort documents by * relevance in descending order * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / - * log_2(rank + 1) + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * --------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0  | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 - * 3 | 2 | 3.0 | 2.0  | 1.5 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339 + * 1 | 3 | 7.0 | 1.0  | 7.0 + * 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 + * 3 | 2 | 3.0 | 2.0  | 1.5 + * 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339 * --------------------------------------------------------------------------------------- - * 5 | n.a | n.a | n.a.  | n.a. 6 | n.a | n.a | n.a  | n.a + * 5 | n.a | n.a | n.a.  | n.a. + * 6 | n.a | n.a | n.a  | n.a * * idcg = 13.347184833073591 (sum of last column) */ - dcg.setNormalize(true); + dcg = new DiscountedCumulativeGain(true, null); assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001); } public void testParseFromXContent() throws IOException { - String xContent = " {\n" + " \"unknown_doc_rating\": 2,\n" + " \"normalize\": true\n" + "}"; + String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }"; try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser); assertEquals(2, dcgAt.getUnknownDocRating().intValue()); @@ -217,29 +230,25 @@ public class DiscountedCumulativeGainTests extends ESTestCase { public void testSerialization() throws IOException { DiscountedCumulativeGain original = createTestItem(); - DiscountedCumulativeGain deserialized = RankEvalTestHelper.copy(original, DiscountedCumulativeGain::new); + DiscountedCumulativeGain deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), + DiscountedCumulativeGain::new); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } public void testEqualsAndHash() throws IOException { - DiscountedCumulativeGain testItem = createTestItem(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), - RankEvalTestHelper.copy(testItem, DiscountedCumulativeGain::new)); + checkEqualsAndHashCode(createTestItem(), original -> { + return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating()); + }, DiscountedCumulativeGainTests::mutateTestItem); } private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) { - boolean normalise = original.getNormalize(); - int unknownDocRating = original.getUnknownDocRating(); - DiscountedCumulativeGain gain = new DiscountedCumulativeGain(); - gain.setNormalize(normalise); - gain.setUnknownDocRating(unknownDocRating); - - List mutators = new ArrayList<>(); - mutators.add(() -> gain.setNormalize(!original.getNormalize())); - mutators.add(() -> gain.setUnknownDocRating(randomValueOtherThan(unknownDocRating, () -> randomIntBetween(0, 10)))); - randomFrom(mutators).run(); - return gain; + if (randomBoolean()) { + return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating()); + } else { + return new DiscountedCumulativeGain(original.getNormalize(), + randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10))); + } } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DocumentKeyTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DocumentKeyTests.java deleted file mode 100644 index 7a241451bc9..00000000000 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DocumentKeyTests.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.index.rankeval; - -import org.elasticsearch.test.ESTestCase; - -import java.io.IOException; - -public class DocumentKeyTests extends ESTestCase { - - static DocumentKey createRandomRatedDocumentKey() { - String index = randomAlphaOfLengthBetween(1, 10); - String docId = randomAlphaOfLengthBetween(1, 10); - return new DocumentKey(index, docId); - } - - public DocumentKey createTestItem() { - return createRandomRatedDocumentKey(); - } - - public DocumentKey mutateTestItem(DocumentKey original) { - String index = original.getIndex(); - String docId = original.getDocID(); - switch (randomIntBetween(0, 1)) { - case 0: - index = index + "_"; - break; - case 1: - docId = docId + "_"; - break; - default: - throw new IllegalStateException("The test should only allow two parameters mutated"); - } - return new DocumentKey(index, docId); - } - - public void testEqualsAndHash() throws IOException { - DocumentKey testItem = createRandomRatedDocumentKey(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), - new DocumentKey(testItem.getIndex(), testItem.getDocID())); - } - - public void testSerialization() throws IOException { - DocumentKey original = createTestItem(); - DocumentKey deserialized = RankEvalTestHelper.copy(original, DocumentKey::new); - assertEquals(deserialized, original); - assertEquals(deserialized.hashCode(), original.hashCode()); - assertNotSame(deserialized, original); - } -} diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java index 97e40f12f30..fb1c7db554a 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java @@ -20,22 +20,24 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; + public class EvalQueryQualityTests extends ESTestCase { - private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry( - new RankEvalPlugin().getNamedWriteables()); + private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(new RankEvalPlugin().getNamedWriteables()); public static EvalQueryQuality randomEvalQueryQuality() { List unknownDocs = new ArrayList<>(); int numberOfUnknownDocs = randomInt(5); for (int i = 0; i < numberOfUnknownDocs; i++) { - unknownDocs.add(DocumentKeyTests.createRandomRatedDocumentKey()); + unknownDocs.add(new DocumentKey(randomAlphaOfLength(10), randomAlphaOfLength(10))); } int numberOfSearchHits = randomInt(5); List ratedHits = new ArrayList<>(); @@ -54,17 +56,18 @@ public class EvalQueryQualityTests extends ESTestCase { public void testSerialization() throws IOException { EvalQueryQuality original = randomEvalQueryQuality(); - EvalQueryQuality deserialized = RankEvalTestHelper.copy(original, EvalQueryQuality::new, - namedWritableRegistry); + EvalQueryQuality deserialized = copy(original); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } + private static EvalQueryQuality copy(EvalQueryQuality original) throws IOException { + return ESTestCase.copyWriteable(original, namedWritableRegistry, EvalQueryQuality::new); + } + public void testEqualsAndHash() throws IOException { - EvalQueryQuality testItem = randomEvalQueryQuality(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), - RankEvalTestHelper.copy(testItem, EvalQueryQuality::new, namedWritableRegistry)); + checkEqualsAndHashCode(randomEvalQueryQuality(), EvalQueryQualityTests::copy, EvalQueryQualityTests::mutateTestItem); } private static EvalQueryQuality mutateTestItem(EvalQueryQuality original) { diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java similarity index 79% rename from modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java index bed24a18505..81586783c06 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java @@ -19,14 +19,15 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.Index; -import org.elasticsearch.index.rankeval.PrecisionTests.Rating; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.test.ESTestCase; @@ -38,21 +39,35 @@ import java.util.Collections; import java.util.List; import java.util.Vector; -public class ReciprocalRankTests extends ESTestCase { +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; + +public class MeanReciprocalRankTests extends ESTestCase { + + public void testParseFromXContent() throws IOException { + String xContent = "{ }"; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { + MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); + assertEquals(1, mrr.getRelevantRatingThreshold()); + } + + xContent = "{ \"relevant_rating_threshold\": 2 }"; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { + MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); + assertEquals(2, mrr.getRelevantRatingThreshold()); + } + } public void testMaxAcceptableRank() { MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(); - int searchHits = randomIntBetween(1, 50); - SearchHit[] hits = createSearchHits(0, searchHits, "test"); List ratedDocs = new ArrayList<>(); int relevantAt = randomIntBetween(0, searchHits); for (int i = 0; i <= searchHits; i++) { if (i == relevantAt) { - ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.RELEVANT.ordinal())); + ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.RELEVANT.ordinal())); } else { - ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.IRRELEVANT.ordinal())); + ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.IRRELEVANT.ordinal())); } } @@ -76,9 +91,9 @@ public class ReciprocalRankTests extends ESTestCase { int relevantAt = randomIntBetween(0, 9); for (int i = 0; i <= 20; i++) { if (i == relevantAt) { - ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.RELEVANT.ordinal())); + ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.RELEVANT.ordinal())); } else { - ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.IRRELEVANT.ordinal())); + ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.IRRELEVANT.ordinal())); } } @@ -101,8 +116,7 @@ public class ReciprocalRankTests extends ESTestCase { rated.add(new RatedDocument("test", "4", 4)); SearchHit[] hits = createSearchHits(0, 5, "test"); - MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(); - reciprocalRank.setRelevantRatingThreshhold(2); + MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated); assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001); assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank()); @@ -153,35 +167,31 @@ public class ReciprocalRankTests extends ESTestCase { } private static MeanReciprocalRank createTestItem() { - MeanReciprocalRank testItem = new MeanReciprocalRank(); - testItem.setRelevantRatingThreshhold(randomIntBetween(0, 20)); - return testItem; + return new MeanReciprocalRank(randomIntBetween(0, 20)); } public void testSerialization() throws IOException { MeanReciprocalRank original = createTestItem(); - - MeanReciprocalRank deserialized = RankEvalTestHelper.copy(original, MeanReciprocalRank::new); + MeanReciprocalRank deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), + MeanReciprocalRank::new); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } public void testEqualsAndHash() throws IOException { - MeanReciprocalRank testItem = createTestItem(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), - RankEvalTestHelper.copy(testItem, MeanReciprocalRank::new)); + checkEqualsAndHashCode(createTestItem(), MeanReciprocalRankTests::copy, MeanReciprocalRankTests::mutate); } - private static MeanReciprocalRank mutateTestItem(MeanReciprocalRank testItem) { - int relevantThreshold = testItem.getRelevantRatingThreshold(); - MeanReciprocalRank rank = new MeanReciprocalRank(); - rank.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10))); - return rank; + private static MeanReciprocalRank copy(MeanReciprocalRank testItem) { + return new MeanReciprocalRank(testItem.getRelevantRatingThreshold()); + } + + private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) { + return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10))); } public void testInvalidRelevantThreshold() { - MeanReciprocalRank prez = new MeanReciprocalRank(); - expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1)); + expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1)); } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java similarity index 74% rename from modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java index d520174c28e..01ea5e7cf65 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -38,11 +39,13 @@ import java.util.Collections; import java.util.List; import java.util.Vector; -public class PrecisionTests extends ESTestCase { +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; + +public class PrecisionAtKTests extends ESTestCase { public void testPrecisionAtFiveCalculation() { List rated = new ArrayList<>(); - rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal())); EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated); assertEquals(1, evaluated.getQualityLevel(), 0.00001); assertEquals(1, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -51,11 +54,11 @@ public class PrecisionTests extends ESTestCase { public void testPrecisionAtFiveIgnoreOneResult() { List rated = new ArrayList<>(); - rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "2", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "3", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "4", Rating.IRRELEVANT.ordinal())); + rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "2", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "3", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "4", TestRatingEnum.IRRELEVANT.ordinal())); EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated); assertEquals((double) 4 / 5, evaluated.getQualityLevel(), 0.00001); assertEquals(4, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -69,13 +72,12 @@ public class PrecisionTests extends ESTestCase { */ public void testPrecisionAtFiveRelevanceThreshold() { List rated = new ArrayList<>(); - rated.add(new RatedDocument("test", "0", 0)); - rated.add(new RatedDocument("test", "1", 1)); - rated.add(new RatedDocument("test", "2", 2)); - rated.add(new RatedDocument("test", "3", 3)); - rated.add(new RatedDocument("test", "4", 4)); - PrecisionAtK precisionAtN = new PrecisionAtK(); - precisionAtN.setRelevantRatingThreshhold(2); + rated.add(createRatedDoc("test", "0", 0)); + rated.add(createRatedDoc("test", "1", 1)); + rated.add(createRatedDoc("test", "2", 2)); + rated.add(createRatedDoc("test", "3", 3)); + rated.add(createRatedDoc("test", "4", 4)); + PrecisionAtK precisionAtN = new PrecisionAtK(2, false); EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test"), rated); assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001); assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -84,11 +86,11 @@ public class PrecisionTests extends ESTestCase { public void testPrecisionAtFiveCorrectIndex() { List rated = new ArrayList<>(); - rated.add(new RatedDocument("test_other", "0", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test_other", "1", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "2", Rating.IRRELEVANT.ordinal())); + rated.add(createRatedDoc("test_other", "0", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test_other", "1", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "2", TestRatingEnum.IRRELEVANT.ordinal())); // the following search hits contain only the last three documents EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated); assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001); @@ -98,8 +100,8 @@ public class PrecisionTests extends ESTestCase { public void testIgnoreUnlabeled() { List rated = new ArrayList<>(); - rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal())); - rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal())); + rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal())); // add an unlabeled search hit SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test"), 3); searchHits[2] = new SearchHit(2, "2", new Text("testtype"), Collections.emptyMap()); @@ -111,8 +113,7 @@ public class PrecisionTests extends ESTestCase { assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved()); // also try with setting `ignore_unlabeled` - PrecisionAtK prec = new PrecisionAtK(); - prec.setIgnoreUnlabeled(true); + PrecisionAtK prec = new PrecisionAtK(1, true); evaluated = prec.evaluate("id", searchHits, rated); assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001); assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -131,8 +132,7 @@ public class PrecisionTests extends ESTestCase { assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved()); // also try with setting `ignore_unlabeled` - PrecisionAtK prec = new PrecisionAtK(); - prec.setIgnoreUnlabeled(true); + PrecisionAtK prec = new PrecisionAtK(1, true); evaluated = prec.evaluate("id", hits, Collections.emptyList()); assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001); assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -158,16 +158,11 @@ public class PrecisionTests extends ESTestCase { public void testInvalidRelevantThreshold() { PrecisionAtK prez = new PrecisionAtK(); - expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1)); + expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false)); } public static PrecisionAtK createTestItem() { - PrecisionAtK precision = new PrecisionAtK(); - if (randomBoolean()) { - precision.setRelevantRatingThreshhold(randomIntBetween(0, 10)); - } - precision.setIgnoreUnlabeled(randomBoolean()); - return precision; + return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean()); } public void testXContentRoundtrip() throws IOException { @@ -186,29 +181,28 @@ public class PrecisionTests extends ESTestCase { public void testSerialization() throws IOException { PrecisionAtK original = createTestItem(); - PrecisionAtK deserialized = RankEvalTestHelper.copy(original, PrecisionAtK::new); + PrecisionAtK deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), + PrecisionAtK::new); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } public void testEqualsAndHash() throws IOException { - PrecisionAtK testItem = createTestItem(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), RankEvalTestHelper.copy(testItem, PrecisionAtK::new)); + checkEqualsAndHashCode(createTestItem(), PrecisionAtKTests::copy, PrecisionAtKTests::mutate); } - private static PrecisionAtK mutateTestItem(PrecisionAtK original) { - boolean ignoreUnlabeled = original.getIgnoreUnlabeled(); - int relevantThreshold = original.getRelevantRatingThreshold(); - PrecisionAtK precision = new PrecisionAtK(); - precision.setIgnoreUnlabeled(ignoreUnlabeled); - precision.setRelevantRatingThreshhold(relevantThreshold); + private static PrecisionAtK copy(PrecisionAtK original) { + return new PrecisionAtK(original.getRelevantRatingThreshold(), original.getIgnoreUnlabeled()); + } - List mutators = new ArrayList<>(); - mutators.add(() -> precision.setIgnoreUnlabeled(!ignoreUnlabeled)); - mutators.add(() -> precision.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10)))); - randomFrom(mutators).run(); - return precision; + private static PrecisionAtK mutate(PrecisionAtK original) { + if (randomBoolean()) { + return new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled()); + } else { + return new PrecisionAtK(randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)), + original.getIgnoreUnlabeled()); + } } private static SearchHit[] toSearchHits(List rated, String index) { @@ -220,7 +214,7 @@ public class PrecisionTests extends ESTestCase { return hits; } - public enum Rating { - IRRELEVANT, RELEVANT; + private static RatedDocument createRatedDoc(String index, String id, int rating) { + return new RatedDocument(index, id, rating); } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java index 2e3f86543c6..3b674d7e8a6 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java @@ -22,7 +22,6 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.rankeval.PrecisionTests.Rating; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESIntegTestCase; @@ -35,7 +34,7 @@ import java.util.List; import java.util.Map.Entry; import java.util.Set; -import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments; +import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments; public class RankEvalRequestIT extends ESIntegTestCase { @Override @@ -82,8 +81,7 @@ public class RankEvalRequestIT extends ESIntegTestCase { specifications.add(berlinRequest); - PrecisionAtK metric = new PrecisionAtK(); - metric.setIgnoreUnlabeled(true); + PrecisionAtK metric = new PrecisionAtK(1, true); RankEvalSpec task = new RankEvalSpec(specifications, metric); RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), @@ -106,7 +104,7 @@ public class RankEvalRequestIT extends ESIntegTestCase { if (id.equals("1") || id.equals("6")) { assertFalse(hit.getRating().isPresent()); } else { - assertEquals(Rating.RELEVANT.ordinal(), hit.getRating().get().intValue()); + assertEquals(TestRatingEnum.RELEVANT.ordinal(), hit.getRating().get().intValue()); } } } @@ -117,7 +115,7 @@ public class RankEvalRequestIT extends ESIntegTestCase { for (RatedSearchHit hit : hitsAndRatings) { String id = hit.getSearchHit().getId(); if (id.equals("1")) { - assertEquals(Rating.RELEVANT.ordinal(), hit.getRating().get().intValue()); + assertEquals(TestRatingEnum.RELEVANT.ordinal(), hit.getRating().get().intValue()); } else { assertFalse(hit.getRating().isPresent()); } @@ -167,7 +165,7 @@ public class RankEvalRequestIT extends ESIntegTestCase { private static List createRelevant(String... docs) { List relevant = new ArrayList<>(); for (String doc : docs) { - relevant.add(new RatedDocument("test", doc, Rating.RELEVANT.ordinal())); + relevant.add(new RatedDocument("test", doc, TestRatingEnum.RELEVANT.ordinal())); } return relevant; } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java index 497f17de1b9..aff758c096b 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalResponseTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -43,7 +44,7 @@ public class RankEvalResponseTests extends ESTestCase { int numberOfUnknownDocs = randomIntBetween(0, 5); List unknownDocs = new ArrayList<>(numberOfUnknownDocs); for (int d = 0; d < numberOfUnknownDocs; d++) { - unknownDocs.add(DocumentKeyTests.createRandomRatedDocumentKey()); + unknownDocs.add(new DocumentKey(randomAlphaOfLength(10), randomAlphaOfLength(10))); } EvalQueryQuality evalQuality = new EvalQueryQuality(id, randomDoubleBetween(0.0, 1.0, true)); @@ -65,12 +66,9 @@ public class RankEvalResponseTests extends ESTestCase { try (StreamInput in = output.bytes().streamInput()) { RankEvalResponse deserializedResponse = new RankEvalResponse(); deserializedResponse.readFrom(in); - assertEquals(randomResponse.getQualityLevel(), - deserializedResponse.getQualityLevel(), Double.MIN_VALUE); - assertEquals(randomResponse.getPartialResults(), - deserializedResponse.getPartialResults()); - assertEquals(randomResponse.getFailures().keySet(), - deserializedResponse.getFailures().keySet()); + assertEquals(randomResponse.getQualityLevel(), deserializedResponse.getQualityLevel(), Double.MIN_VALUE); + assertEquals(randomResponse.getPartialResults(), deserializedResponse.getPartialResults()); + assertEquals(randomResponse.getFailures().keySet(), deserializedResponse.getFailures().keySet()); assertNotSame(randomResponse, deserializedResponse); assertEquals(-1, in.read()); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java index 40624f12dab..6197efb1b04 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java @@ -45,6 +45,8 @@ import java.util.Map; import java.util.Map.Entry; import java.util.function.Supplier; +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; + public class RankEvalSpecTests extends ESTestCase { private static List randomList(Supplier randomSupplier) { @@ -57,9 +59,9 @@ public class RankEvalSpecTests extends ESTestCase { } private static RankEvalSpec createTestItem() throws IOException { - RankedListQualityMetric metric; + EvaluationMetric metric; if (randomBoolean()) { - metric = PrecisionTests.createTestItem(); + metric = PrecisionAtKTests.createTestItem(); } else { metric = DiscountedCumulativeGainTests.createTestItem(); } @@ -111,41 +113,30 @@ public class RankEvalSpecTests extends ESTestCase { public void testSerialization() throws IOException { RankEvalSpec original = createTestItem(); - - List namedWriteables = new ArrayList<>(); - namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtK.NAME, PrecisionAtK::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME, - DiscountedCumulativeGain::new)); - namedWriteables - .add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); - - RankEvalSpec deserialized = RankEvalTestHelper.copy(original, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)); + RankEvalSpec deserialized = copy(original); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } - public void testEqualsAndHash() throws IOException { - RankEvalSpec testItem = createTestItem(); - + private static RankEvalSpec copy(RankEvalSpec original) throws IOException { List namedWriteables = new ArrayList<>(); namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtK.NAME, PrecisionAtK::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME, - DiscountedCumulativeGain::new)); - namedWriteables - .add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); - - RankEvalSpec mutant = RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(mutant), - RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables))); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); + return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(namedWriteables), RankEvalSpec::new); } - private static RankEvalSpec mutateTestItem(RankEvalSpec mutant) { - Collection ratedRequests = mutant.getRatedRequests(); - RankedListQualityMetric metric = mutant.getMetric(); - Map templates = mutant.getTemplates(); + public void testEqualsAndHash() throws IOException { + checkEqualsAndHashCode(createTestItem(), RankEvalSpecTests::copy, RankEvalSpecTests::mutateTestItem); + } + + private static RankEvalSpec mutateTestItem(RankEvalSpec original) { + List ratedRequests = new ArrayList<>(original.getRatedRequests()); + EvaluationMetric metric = original.getMetric(); + Map templates = original.getTemplates(); int mutate = randomIntBetween(0, 2); switch (mutate) { @@ -177,7 +168,7 @@ public class RankEvalSpecTests extends ESTestCase { } public void testMissingRatedRequestsFailsParsing() { - RankedListQualityMetric metric = new PrecisionAtK(); + EvaluationMetric metric = new PrecisionAtK(); expectThrows(IllegalStateException.class, () -> new RankEvalSpec(new ArrayList<>(), metric)); expectThrows(IllegalStateException.class, () -> new RankEvalSpec(null, metric)); } @@ -189,7 +180,7 @@ public class RankEvalSpecTests extends ESTestCase { } public void testMissingTemplateAndSearchRequestFailsParsing() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); Map params = new HashMap<>(); params.put("key", "value"); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalTestHelper.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalTestHelper.java deleted file mode 100644 index 088db7df814..00000000000 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalTestHelper.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.index.rankeval; - -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.Writeable; - -import java.io.IOException; -import java.util.Collections; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.not; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; - -// TODO replace by infra from ESTestCase -public class RankEvalTestHelper { - - public static void testHashCodeAndEquals(T testItem, T mutation, T secondCopy) { - assertFalse("testItem is equal to null", testItem.equals(null)); - assertFalse("testItem is equal to incompatible type", testItem.equals("")); - assertTrue("testItem is not equal to self", testItem.equals(testItem)); - assertThat("same testItem's hashcode returns different values if called multiple times", - testItem.hashCode(), equalTo(testItem.hashCode())); - - assertThat("different testItem should not be equal", mutation, not(equalTo(testItem))); - - assertNotSame("testItem copy is not same as original", testItem, secondCopy); - assertTrue("testItem is not equal to its copy", testItem.equals(secondCopy)); - assertTrue("equals is not symmetric", secondCopy.equals(testItem)); - assertThat("testItem copy's hashcode is different from original hashcode", - secondCopy.hashCode(), equalTo(testItem.hashCode())); - } - - /** - * Make a deep copy of an object by running it through a BytesStreamOutput - * - * @param original - * the original object - * @param reader - * a function able to create a new copy of this type - * @return a new copy of the original object - */ - public static T copy(T original, Writeable.Reader reader) - throws IOException { - return copy(original, reader, new NamedWriteableRegistry(Collections.emptyList())); - } - - /** - * Make a deep copy of an object by running it through a BytesStreamOutput - * - * @param original - * the original object - * @param reader - * a function able to create a new copy of this type - * @param namedWriteableRegistry - * must be non-empty if the object itself or nested object - * implement {@link NamedWriteable} - * @return a new copy of the original object - */ - public static T copy(T original, Writeable.Reader reader, - NamedWriteableRegistry namedWriteableRegistry) throws IOException { - try (BytesStreamOutput output = new BytesStreamOutput()) { - original.writeTo(output); - try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), - namedWriteableRegistry)) { - return reader.read(in); - } - } - } -} diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java index 672d464a386..5ec9692d83a 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; @@ -27,15 +28,14 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.Collections; + +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; public class RatedDocumentTests extends ESTestCase { public static RatedDocument createRatedDocument() { - String index = randomAlphaOfLength(10); - String docId = randomAlphaOfLength(10); - int rating = randomInt(); - - return new RatedDocument(index, docId, rating); + return new RatedDocument(randomAlphaOfLength(10), randomAlphaOfLength(10), randomInt()); } public void testXContentParsing() throws IOException { @@ -52,22 +52,17 @@ public class RatedDocumentTests extends ESTestCase { public void testSerialization() throws IOException { RatedDocument original = createRatedDocument(); - RatedDocument deserialized = RankEvalTestHelper.copy(original, RatedDocument::new); + RatedDocument deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), + RatedDocument::new); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } public void testEqualsAndHash() throws IOException { - RatedDocument testItem = createRatedDocument(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), RankEvalTestHelper.copy(testItem, RatedDocument::new)); - } - - public void testInvalidParsing() { - expectThrows(IllegalArgumentException.class, () -> new RatedDocument(null, "abc", 10)); - expectThrows(IllegalArgumentException.class, () -> new RatedDocument("", "abc", 10)); - expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", "", 10)); - expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", null, 10)); + checkEqualsAndHashCode(createRatedDocument(), original -> { + return new RatedDocument(original.getIndex(), original.getDocID(), original.getRating()); + }, RatedDocumentTests::mutateTestItem); } private static RatedDocument mutateTestItem(RatedDocument original) { diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java index 58a9f998183..53846406d9c 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java @@ -47,6 +47,7 @@ import java.util.stream.Stream; import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toList; +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; public class RatedRequestsTests extends ESTestCase { @@ -140,32 +141,26 @@ public class RatedRequestsTests extends ESTestCase { for (int i = 0; i < size; i++) { indices.add(randomAlphaOfLengthBetween(0, 50)); } - RatedRequest original = createTestItem(indices, randomBoolean()); - - List namedWriteables = new ArrayList<>(); - namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); - - RatedRequest deserialized = RankEvalTestHelper.copy(original, RatedRequest::new, new NamedWriteableRegistry(namedWriteables)); + RatedRequest deserialized = copy(original); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } + private static RatedRequest copy(RatedRequest original) throws IOException { + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); + return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(namedWriteables), RatedRequest::new); + } + public void testEqualsAndHash() throws IOException { List indices = new ArrayList<>(); int size = randomIntBetween(0, 20); for (int i = 0; i < size; i++) { indices.add(randomAlphaOfLengthBetween(0, 50)); } - - RatedRequest testItem = createTestItem(indices, randomBoolean()); - - List namedWriteables = new ArrayList<>(); - namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); - - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), - RankEvalTestHelper.copy(testItem, RatedRequest::new, new NamedWriteableRegistry(namedWriteables))); + checkEqualsAndHashCode(createTestItem(indices, randomBoolean()), RatedRequestsTests::copy, RatedRequestsTests::mutateTestItem); } private static RatedRequest mutateTestItem(RatedRequest original) { @@ -220,8 +215,7 @@ public class RatedRequestsTests extends ESTestCase { } public void testDuplicateRatedDocThrowsException() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1), - new RatedDocument(new DocumentKey("index1", "id1"), 5)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1), new RatedDocument("index1", "id1", 5)); // search request set, no summary fields IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, @@ -237,45 +231,45 @@ public class RatedRequestsTests extends ESTestCase { } public void testNullSummaryFieldsTreatment() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); RatedRequest request = new RatedRequest("id", ratedDocs, new SearchSourceBuilder()); expectThrows(IllegalArgumentException.class, () -> request.setSummaryFields(null)); } public void testNullParamsTreatment() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); RatedRequest request = new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), null, null); assertNotNull(request.getParams()); } public void testSettingParamsAndRequestThrows() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); Map params = new HashMap<>(); params.put("key", "value"); expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), params, null)); } public void testSettingNeitherParamsNorRequestThrows() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, null)); expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, new HashMap<>(), "templateId")); } public void testSettingParamsWithoutTemplateIdThrows() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); Map params = new HashMap<>(); params.put("key", "value"); expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, params, null)); } public void testSettingTemplateIdAndRequestThrows() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), null, "templateId")); } public void testSettingTemplateIdNoParamsThrows() { - List ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1)); + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, null, "templateId")); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java index 3899a2e2029..cf66b0b7797 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedSearchHitTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; @@ -27,6 +28,8 @@ import java.io.IOException; import java.util.Collections; import java.util.Optional; +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; + public class RatedSearchHitTests extends ESTestCase { public static RatedSearchHit randomRatedSearchHit() { @@ -57,15 +60,18 @@ public class RatedSearchHitTests extends ESTestCase { public void testSerialization() throws IOException { RatedSearchHit original = randomRatedSearchHit(); - RatedSearchHit deserialized = RankEvalTestHelper.copy(original, RatedSearchHit::new); + RatedSearchHit deserialized = copy(original); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } public void testEqualsAndHash() throws IOException { - RatedSearchHit testItem = randomRatedSearchHit(); - RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), - RankEvalTestHelper.copy(testItem, RatedSearchHit::new)); + checkEqualsAndHashCode(randomRatedSearchHit(), RatedSearchHitTests::copy, RatedSearchHitTests::mutateTestItem); } + + private static RatedSearchHit copy(RatedSearchHit original) throws IOException { + return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), RatedSearchHit::new); + } + } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/TestRatingEnum.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/TestRatingEnum.java new file mode 100644 index 00000000000..ea44c215d92 --- /dev/null +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/TestRatingEnum.java @@ -0,0 +1,24 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +enum TestRatingEnum { + IRRELEVANT, RELEVANT; +} \ No newline at end of file diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml index 0df0110993a..130da28f3b1 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/30_failures.yml @@ -35,4 +35,4 @@ - match: { rank_eval.details.amsterdam_query.unknown_docs: [ ]} - match: { rank_eval.details.amsterdam_query.metric_details: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}} - - is_true: rank_eval.failures.invalid_queryy + - is_true: rank_eval.failures.invalid_query