From 72d0de4197d6370873f14f8a15396dc4bd30652d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= <10398885+cbuescher@users.noreply.github.com> Date: Mon, 4 Dec 2017 10:54:03 +0100 Subject: [PATCH] Add search window parameter k to MRR and DCG metric (#27595) --- .../rankeval/DiscountedCumulativeGain.java | 47 +++++-- .../index/rankeval/MeanReciprocalRank.java | 50 +++++--- .../index/rankeval/PrecisionAtK.java | 6 +- .../DiscountedCumulativeGainTests.java | 62 ++++++--- .../rankeval/MeanReciprocalRankTests.java | 34 ++++- .../index/rankeval/RankEvalRequestIT.java | 118 +++++++++++++++--- 6 files changed, 248 insertions(+), 69 deletions(-) diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java index 141d45c274b..64d4ada0dc1 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java @@ -33,21 +33,29 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; /** - * Metric implementing Discounted Cumulative Gain (https://en.wikipedia.org/wiki/Discounted_cumulative_gain).
+ * Metric implementing Discounted Cumulative Gain. * The `normalize` parameter can be set to calculate the normalized NDCG (set to false by default).
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents. + * @see <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain">Discounted Cumulative Gain</a>
*/ public class DiscountedCumulativeGain implements EvaluationMetric { /** If set to true, the dcg will be normalized (ndcg) */ private final boolean normalize; + /** the default search window size */ + private static final int DEFAULT_K = 10; + + /** the search window size */ + private final int k; + /** * Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request */ @@ -57,7 +65,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric { private static final double LOG2 = Math.log(2.0); public DiscountedCumulativeGain() { - this(false, null); + this(false, null, DEFAULT_K); } /** @@ -65,23 +73,27 @@ public class DiscountedCumulativeGain implements EvaluationMetric { * If set to true, dcg will be normalized (ndcg) See * https://en.wikipedia.org/wiki/Discounted_cumulative_gain * @param unknownDocRating - * the rating for docs the user hasn't supplied an explicit + * the rating for documents the user hasn't supplied an explicit * rating for + * @param k the search window size all request use. 
*/ - public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating) { + public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating, int k) { this.normalize = normalize; this.unknownDocRating = unknownDocRating; + this.k = k; } DiscountedCumulativeGain(StreamInput in) throws IOException { normalize = in.readBoolean(); unknownDocRating = in.readOptionalVInt(); + k = in.readVInt(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(normalize); out.writeOptionalVInt(unknownDocRating); + out.writeVInt(k); } @Override @@ -89,13 +101,14 @@ public class DiscountedCumulativeGain implements EvaluationMetric { return NAME; } - /** - * check whether this metric computes only dcg or "normalized" ndcg - */ - public boolean getNormalize() { + boolean getNormalize() { return this.normalize; } + int getK() { + return this.k; + } + /** * get the rating used for unrated documents */ @@ -103,6 +116,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric { return this.unknownDocRating; } + + @Override + public Optional forcedSearchSize() { + return Optional.of(k); + } + @Override public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List ratedDocs) { @@ -142,17 +161,21 @@ public class DiscountedCumulativeGain implements EvaluationMetric { return dcg; } + private static final ParseField K_FIELD = new ParseField("k"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at", args -> { Boolean normalized = (Boolean) args[0]; - return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]); + Integer optK = (Integer) args[2]; + return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1], + optK == null ? 
DEFAULT_K : optK); }); static { PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD); PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD); + PARSER.declareInt(optionalConstructorArg(), K_FIELD); } public static DiscountedCumulativeGain fromXContent(XContentParser parser) { @@ -167,6 +190,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric { if (unknownDocRating != null) { builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating); } + builder.field(K_FIELD.getPreferredName(), this.k); builder.endObject(); builder.endObject(); return builder; @@ -182,11 +206,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric { } DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj; return Objects.equals(normalize, other.normalize) - && Objects.equals(unknownDocRating, other.unknownDocRating); + && Objects.equals(unknownDocRating, other.unknownDocRating) + && Objects.equals(k, other.k); } @Override public final int hashCode() { - return Objects.hash(normalize, unknownDocRating); + return Objects.hash(normalize, unknownDocRating, k); } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java index 057bff6e147..a74fd8da3e6 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java @@ -42,37 +42,57 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati */ public class MeanReciprocalRank implements EvaluationMetric { - private static final int DEFAULT_RATING_THRESHOLD = 1; - public static final String NAME = "mean_reciprocal_rank"; - /** ratings equal or above this value will be considered relevant. 
*/ + private static final int DEFAULT_RATING_THRESHOLD = 1; + private static final int DEFAULT_K = 10; + + /** the search window size */ + private final int k; + + /** ratings equal or above this value will be considered relevant */ private final int relevantRatingThreshhold; public MeanReciprocalRank() { - this(DEFAULT_RATING_THRESHOLD); + this(DEFAULT_RATING_THRESHOLD, DEFAULT_K); } MeanReciprocalRank(StreamInput in) throws IOException { this.relevantRatingThreshhold = in.readVInt(); + this.k = in.readVInt(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(relevantRatingThreshhold); + out.writeVInt(this.relevantRatingThreshhold); + out.writeVInt(this.k); } /** * Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).
- * @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevalnt". Defaults to 1. + * @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevant". Defaults to 1. + * @param k the search window size all request use. */ - public MeanReciprocalRank(int relevantRatingThreshold) { + public MeanReciprocalRank(int relevantRatingThreshold, int k) { if (relevantRatingThreshold < 0) { throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer."); } + if (k <= 0) { + throw new IllegalArgumentException("Window size k must be positive."); + } + this.k = k; this.relevantRatingThreshhold = relevantRatingThreshold; } + int getK() { + return this.k; + } + + @Override + public Optional forcedSearchSize() { + return Optional.of(k); + } + @Override public String getWriteableName() { return NAME; @@ -113,18 +133,18 @@ public class MeanReciprocalRank implements EvaluationMetric { } private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); + private static final ParseField K_FIELD = new ParseField("k"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("reciprocal_rank", args -> { Integer optionalThreshold = (Integer) args[0]; - if (optionalThreshold == null) { - return new MeanReciprocalRank(); - } else { - return new MeanReciprocalRank(optionalThreshold); - } + Integer optionalK = (Integer) args[1]; + return new MeanReciprocalRank(optionalThreshold == null ? DEFAULT_RATING_THRESHOLD : optionalThreshold, + optionalK == null ? 
DEFAULT_K : optionalK); }); static { PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD); + PARSER.declareInt(optionalConstructorArg(), K_FIELD); } public static MeanReciprocalRank fromXContent(XContentParser parser) { @@ -136,6 +156,7 @@ public class MeanReciprocalRank implements EvaluationMetric { builder.startObject(); builder.startObject(NAME); builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold); + builder.field(K_FIELD.getPreferredName(), this.k); builder.endObject(); builder.endObject(); return builder; @@ -150,12 +171,13 @@ public class MeanReciprocalRank implements EvaluationMetric { return false; } MeanReciprocalRank other = (MeanReciprocalRank) obj; - return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold); + return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold) + && Objects.equals(k, other.k); } @Override public final int hashCode() { - return Objects.hash(relevantRatingThreshhold); + return Objects.hash(relevantRatingThreshhold, k); } static class Breakdown implements MetricDetails { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java index 4beeeea2b40..63bdcb7307d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java @@ -96,7 +96,7 @@ public class PrecisionAtK implements EvaluationMetric { Integer k = (Integer) args[2]; return new PrecisionAtK(threshHold == null ? 1 : threshHold, ignoreUnlabeled == null ? false : ignoreUnlabeled, - k == null ? 10 : k); + k == null ? 
DEFAULT_K : k); }); static { @@ -111,6 +111,10 @@ public class PrecisionAtK implements EvaluationMetric { k = in.readVInt(); } + int getK() { + return this.k; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeVInt(relevantRatingThreshhold); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java index f3c38b7ae64..00f8a3018d4 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java @@ -42,13 +42,18 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC public class DiscountedCumulativeGainTests extends ESTestCase { + static final double EXPECTED_DCG = 13.84826362927298; + static final double EXPECTED_IDCG = 14.595390756454922; + static final double EXPECTED_NDCG = EXPECTED_DCG / EXPECTED_IDCG; + private static final double DELTA = 10E-16; + /** * Assuming the docs are ranked in the following order: * * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * ------------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0 | 7.0 2 |  - * 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 |  + * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 * 3 | 3 | 7.0 | 2.0 | 3.5 * 4 | 0 | 0.0 | 2.321928094887362 | 0.0 * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 @@ -66,7 +71,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase { hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null)); } DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); - assertEquals(13.84826362927298, dcg.evaluate("id", hits, 
rated).getQualityLevel(), 0.00001); + assertEquals(EXPECTED_DCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA); /** * Check with normalization: to get the maximal possible dcg, sort documents by @@ -83,8 +88,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * * idcg = 14.595390756454922 (sum of last column) */ - dcg = new DiscountedCumulativeGain(true, null); - assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); + dcg = new DiscountedCumulativeGain(true, null, 10); + assertEquals(EXPECTED_NDCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA); } /** @@ -117,7 +122,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase { } DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); EvalQueryQuality result = dcg.evaluate("id", hits, rated); - assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001); + assertEquals(12.779642067948913, result.getQualityLevel(), DELTA); assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size()); /** @@ -135,8 +140,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * * idcg = 13.347184833073591 (sum of last column) */ - dcg = new DiscountedCumulativeGain(true, null); - assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); + dcg = new DiscountedCumulativeGain(true, null, 10); + assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA); } /** @@ -174,7 +179,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase { } DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs); - assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001); + assertEquals(12.392789260714371, result.getQualityLevel(), DELTA); assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size()); /** 
@@ -193,16 +198,27 @@ public class DiscountedCumulativeGainTests extends ESTestCase { * * idcg = 13.347184833073591 (sum of last column) */ - dcg = new DiscountedCumulativeGain(true, null); - assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001); + dcg = new DiscountedCumulativeGain(true, null, 10); + assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), DELTA); } public void testParseFromXContent() throws IOException { - String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }"; + assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true, \"k\" : 15 }", 2, true, 15); + assertParsedCorrect("{ \"normalize\": false, \"k\" : 15 }", null, false, 15); + assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"k\" : 15 }", 2, false, 15); + assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true }", 2, true, 10); + assertParsedCorrect("{ \"normalize\": true }", null, true, 10); + assertParsedCorrect("{ \"k\": 23 }", null, false, 23); + assertParsedCorrect("{ \"unknown_doc_rating\": 2 }", 2, false, 10); + } + + private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, boolean expectedNormalize, int expectedK) + throws IOException { try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser); - assertEquals(2, dcgAt.getUnknownDocRating().intValue()); - assertEquals(true, dcgAt.getNormalize()); + assertEquals(expectedUnknownDocRating, dcgAt.getUnknownDocRating()); + assertEquals(expectedNormalize, dcgAt.getNormalize()); + assertEquals(expectedK, dcgAt.getK()); } } @@ -210,7 +226,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase { boolean normalize = randomBoolean(); Integer unknownDocRating = new Integer(randomIntBetween(0, 1000)); - return new DiscountedCumulativeGain(normalize, 
unknownDocRating); + return new DiscountedCumulativeGain(normalize, unknownDocRating, 10); } public void testXContentRoundtrip() throws IOException { @@ -238,16 +254,22 @@ public class DiscountedCumulativeGainTests extends ESTestCase { public void testEqualsAndHash() throws IOException { checkEqualsAndHashCode(createTestItem(), original -> { - return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating()); + return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), original.getK()); }, DiscountedCumulativeGainTests::mutateTestItem); } private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) { - if (randomBoolean()) { - return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating()); - } else { + switch (randomIntBetween(0, 2)) { + case 0: + return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating(), original.getK()); + case 1: return new DiscountedCumulativeGain(original.getNormalize(), - randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10))); + randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK()); + case 2: + return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), + randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10))); + default: + throw new IllegalArgumentException("mutation variant not allowed"); } } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java index 3a1b4939758..42f7e32671f 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java @@ -48,12 +48,28 @@ public class 
MeanReciprocalRankTests extends ESTestCase { try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); assertEquals(1, mrr.getRelevantRatingThreshold()); + assertEquals(10, mrr.getK()); } xContent = "{ \"relevant_rating_threshold\": 2 }"; try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); assertEquals(2, mrr.getRelevantRatingThreshold()); + assertEquals(10, mrr.getK()); + } + + xContent = "{ \"relevant_rating_threshold\": 2, \"k\" : 15 }"; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { + MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); + assertEquals(2, mrr.getRelevantRatingThreshold()); + assertEquals(15, mrr.getK()); + } + + xContent = "{ \"k\" : 15 }"; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { + MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); + assertEquals(1, mrr.getRelevantRatingThreshold()); + assertEquals(15, mrr.getK()); } } @@ -116,7 +132,7 @@ public class MeanReciprocalRankTests extends ESTestCase { rated.add(new RatedDocument("test", "4", 4)); SearchHit[] hits = createSearchHits(0, 5, "test"); - MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2); + MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated); assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001); assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank()); @@ -167,7 +183,7 @@ public class MeanReciprocalRankTests extends ESTestCase { } static MeanReciprocalRank createTestItem() { - return new MeanReciprocalRank(randomIntBetween(0, 20)); + return new MeanReciprocalRank(randomIntBetween(0, 20), randomIntBetween(1, 20)); } public void 
testSerialization() throws IOException { @@ -184,14 +200,22 @@ public class MeanReciprocalRankTests extends ESTestCase { } private static MeanReciprocalRank copy(MeanReciprocalRank testItem) { - return new MeanReciprocalRank(testItem.getRelevantRatingThreshold()); + return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK()); } private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) { - return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10))); + if (randomBoolean()) { + return new MeanReciprocalRank(testItem.getRelevantRatingThreshold() + 1, testItem.getK()); + } else { + return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK() + 1); + } } public void testInvalidRelevantThreshold() { - expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1)); + expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1, 1)); + } + + public void testInvalidK() { + expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(1, -1)); } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java index 2ab3e2d9a57..1c10da61fa1 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java @@ -20,18 +20,17 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.client.Client; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.sort.FieldSortBuilder; import 
org.elasticsearch.test.ESIntegTestCase; import org.junit.Before; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map.Entry; import java.util.Set; @@ -64,13 +63,16 @@ public class RankEvalRequestIT extends ESIntegTestCase { refresh(); } + /** + * Test cases retrieves all six documents indexed above. The first part checks the Prec@10 calculation where + * all unlabeled docs are treated as "unrelevant". We average Prec@ metric across two search use cases, the + * first one that labels 4 out of the 6 documents as relevant, the second one with only one relevant document. + */ public void testPrecisionAtRequest() { - List indices = Arrays.asList(new String[] { "test" }); - List specifications = new ArrayList<>(); SearchSourceBuilder testQuery = new SearchSourceBuilder(); testQuery.query(new MatchAllQueryBuilder()); - testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME); + testQuery.sort("_id"); RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", createRelevant("2", "3", "4", "5"), testQuery); amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" })); @@ -79,12 +81,11 @@ public class RankEvalRequestIT extends ESIntegTestCase { RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"), testQuery); berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" })); - specifications.add(berlinRequest); PrecisionAtK metric = new PrecisionAtK(1, false, 10); RankEvalSpec task = new RankEvalSpec(specifications, metric); - task.addIndices(indices); + task.addIndices(Collections.singletonList("test")); RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); @@ -92,6 +93,8 @@ public class RankEvalRequestIT extends ESIntegTestCase { RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()) .actionGet(); + // the 
expected Prec@ for the first query is 4/6 and the expected Prec@ for the + // second is 1/6, divided by 2 to get the average double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0; assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE); Set> entrySet = response.getPartialResults().entrySet(); @@ -129,14 +132,96 @@ public class RankEvalRequestIT extends ESIntegTestCase { // test that a different window size k affects the result metric = new PrecisionAtK(1, false, 3); task = new RankEvalSpec(specifications, metric); - task.addIndices(indices); + task.addIndices(Collections.singletonList("test")); builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); builder.setRankEvalSpec(task); response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + // if we look only at top 3 documente, the expected P@3 for the first query is + // 2/3 and the expected Prec@ for the second is 1/3, divided by 2 to get the average expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0; - assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE); + assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE); + } + + /** + * This test assumes we are using the same ratings as in {@link DiscountedCumulativeGainTests#testDCGAt()}. 
+ * See details in that test case for how the expected values are calculated + */ + public void testDCGRequest() { + SearchSourceBuilder testQuery = new SearchSourceBuilder(); + testQuery.query(new MatchAllQueryBuilder()); + testQuery.sort("_id"); + + List specifications = new ArrayList<>(); + List ratedDocs = Arrays.asList( + new RatedDocument("test", "1", 3), + new RatedDocument("test", "2", 2), + new RatedDocument("test", "3", 3), + new RatedDocument("test", "4", 0), + new RatedDocument("test", "5", 1), + new RatedDocument("test", "6", 2)); + specifications.add(new RatedRequest("amsterdam_query", ratedDocs, testQuery)); + + DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10); + RankEvalSpec task = new RankEvalSpec(specifications, metric); + task.addIndices(Collections.singletonList("test")); + + RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); + builder.setRankEvalSpec(task); + + RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), Double.MIN_VALUE); + + // test that a different window size k affects the result + metric = new DiscountedCumulativeGain(false, null, 3); + task = new RankEvalSpec(specifications, metric); + task.addIndices(Collections.singletonList("test")); + + builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); + builder.setRankEvalSpec(task); + + response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + assertEquals(12.392789260714371, response.getEvaluationResult(), Double.MIN_VALUE); + } + + public void testMRRRequest() { + SearchSourceBuilder testQuery = new SearchSourceBuilder(); + testQuery.query(new MatchAllQueryBuilder()); + testQuery.sort("_id"); + + List specifications = new ArrayList<>(); + specifications.add(new 
RatedRequest("amsterdam_query", createRelevant("5"), testQuery)); + specifications.add(new RatedRequest("berlin_query", createRelevant("1"), testQuery)); + + MeanReciprocalRank metric = new MeanReciprocalRank(1, 10); + RankEvalSpec task = new RankEvalSpec(specifications, metric); + task.addIndices(Collections.singletonList("test")); + + RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); + builder.setRankEvalSpec(task); + + RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + // the expected reciprocal rank for the amsterdam_query is 1/5 + // the expected reciprocal rank for the berlin_query is 1/1 + // dividing by 2 to get the average + double expectedMRR = (1.0 / 1.0 + 1.0 / 5.0) / 2.0; + assertEquals(expectedMRR, response.getEvaluationResult(), 0.0); + + // test that a different window size k affects the result + metric = new MeanReciprocalRank(1, 3); + task = new RankEvalSpec(specifications, metric); + task.addIndices(Collections.singletonList("test")); + + builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); + builder.setRankEvalSpec(task); + + response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + // limiting to top 3 results, the amsterdam_query has no relevant document in it + // the reciprocal rank for the berlin_query is 1/1 + // dividing by 2 to get the average + expectedMRR = (1.0/ 1.0) / 2.0; + assertEquals(expectedMRR, response.getEvaluationResult(), 0.0); } /** @@ -162,16 +247,13 @@ public class RankEvalRequestIT extends ESIntegTestCase { RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK()); task.addIndices(indices); - try (Client client = client()) { - RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest()); - builder.setRankEvalSpec(task); + RankEvalRequestBuilder builder = new 
RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); + builder.setRankEvalSpec(task); - RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); - assertEquals(1, response.getFailures().size()); - ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query")); - assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", - rootCauses[0].getCause().toString()); - } + RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + assertEquals(1, response.getFailures().size()); + ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query")); + assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", rootCauses[0].getCause().toString()); } private static List createRelevant(String... docs) {