From 72d0de4197d6370873f14f8a15396dc4bd30652d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20B=C3=BCscher?=
<10398885+cbuescher@users.noreply.github.com>
Date: Mon, 4 Dec 2017 10:54:03 +0100
Subject: [PATCH] Add search window parameter k to MRR and DCG metric (#27595)
---
.../rankeval/DiscountedCumulativeGain.java | 47 +++++--
.../index/rankeval/MeanReciprocalRank.java | 50 +++++---
.../index/rankeval/PrecisionAtK.java | 6 +-
.../DiscountedCumulativeGainTests.java | 62 ++++++---
.../rankeval/MeanReciprocalRankTests.java | 34 ++++-
.../index/rankeval/RankEvalRequestIT.java | 118 +++++++++++++++---
6 files changed, 248 insertions(+), 69 deletions(-)
diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java
index 141d45c274b..64d4ada0dc1 100644
--- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java
+++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java
@@ -33,21 +33,29 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
/**
- * Metric implementing Discounted Cumulative Gain (https://en.wikipedia.org/wiki/Discounted_cumulative_gain).
+ * Metric implementing Discounted Cumulative Gain.
* The `normalize` parameter can be set to calculate the normalized NDCG (set to false by default).
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents.
+ * @see Discounted Cumulative Gain
*/
public class DiscountedCumulativeGain implements EvaluationMetric {
/** If set to true, the dcg will be normalized (ndcg) */
private final boolean normalize;
+ /** the default search window size */
+ private static final int DEFAULT_K = 10;
+
+ /** the search window size */
+ private final int k;
+
/**
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
*/
@@ -57,7 +65,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
private static final double LOG2 = Math.log(2.0);
public DiscountedCumulativeGain() {
- this(false, null);
+ this(false, null, DEFAULT_K);
}
/**
@@ -65,23 +73,27 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
* If set to true, dcg will be normalized (ndcg) See
* https://en.wikipedia.org/wiki/Discounted_cumulative_gain
* @param unknownDocRating
- * the rating for docs the user hasn't supplied an explicit
+ * the rating for documents the user hasn't supplied an explicit
* rating for
+ * @param k the search window size all request use.
*/
- public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating) {
+ public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating, int k) {
this.normalize = normalize;
this.unknownDocRating = unknownDocRating;
+ this.k = k;
}
DiscountedCumulativeGain(StreamInput in) throws IOException {
normalize = in.readBoolean();
unknownDocRating = in.readOptionalVInt();
+ k = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(normalize);
out.writeOptionalVInt(unknownDocRating);
+ out.writeVInt(k);
}
@Override
@@ -89,13 +101,14 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return NAME;
}
- /**
- * check whether this metric computes only dcg or "normalized" ndcg
- */
- public boolean getNormalize() {
+ boolean getNormalize() {
return this.normalize;
}
+ int getK() {
+ return this.k;
+ }
+
/**
* get the rating used for unrated documents
*/
@@ -103,6 +116,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return this.unknownDocRating;
}
+
+ @Override
+ public Optional forcedSearchSize() {
+ return Optional.of(k);
+ }
+
@Override
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
List ratedDocs) {
@@ -142,17 +161,21 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return dcg;
}
+ private static final ParseField K_FIELD = new ParseField("k");
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at",
args -> {
Boolean normalized = (Boolean) args[0];
- return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
+ Integer optK = (Integer) args[2];
+ return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1],
+ optK == null ? DEFAULT_K : optK);
});
static {
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
+ PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
@@ -167,6 +190,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
if (unknownDocRating != null) {
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
}
+ builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
@@ -182,11 +206,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
}
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
return Objects.equals(normalize, other.normalize)
- && Objects.equals(unknownDocRating, other.unknownDocRating);
+ && Objects.equals(unknownDocRating, other.unknownDocRating)
+ && Objects.equals(k, other.k);
}
@Override
public final int hashCode() {
- return Objects.hash(normalize, unknownDocRating);
+ return Objects.hash(normalize, unknownDocRating, k);
}
}
diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java
index 057bff6e147..a74fd8da3e6 100644
--- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java
+++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MeanReciprocalRank.java
@@ -42,37 +42,57 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
*/
public class MeanReciprocalRank implements EvaluationMetric {
- private static final int DEFAULT_RATING_THRESHOLD = 1;
-
public static final String NAME = "mean_reciprocal_rank";
- /** ratings equal or above this value will be considered relevant. */
+ private static final int DEFAULT_RATING_THRESHOLD = 1;
+ private static final int DEFAULT_K = 10;
+
+ /** the search window size */
+ private final int k;
+
+ /** ratings equal or above this value will be considered relevant */
private final int relevantRatingThreshhold;
public MeanReciprocalRank() {
- this(DEFAULT_RATING_THRESHOLD);
+ this(DEFAULT_RATING_THRESHOLD, DEFAULT_K);
}
MeanReciprocalRank(StreamInput in) throws IOException {
this.relevantRatingThreshhold = in.readVInt();
+ this.k = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
- out.writeVInt(relevantRatingThreshhold);
+ out.writeVInt(this.relevantRatingThreshhold);
+ out.writeVInt(this.k);
}
/**
* Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).
- * @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevalnt". Defaults to 1.
+ * @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevant". Defaults to 1.
+ * @param k the search window size all request use.
*/
- public MeanReciprocalRank(int relevantRatingThreshold) {
+ public MeanReciprocalRank(int relevantRatingThreshold, int k) {
if (relevantRatingThreshold < 0) {
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
}
+ if (k <= 0) {
+ throw new IllegalArgumentException("Window size k must be positive.");
+ }
+ this.k = k;
this.relevantRatingThreshhold = relevantRatingThreshold;
}
+ int getK() {
+ return this.k;
+ }
+
+ @Override
+ public Optional forcedSearchSize() {
+ return Optional.of(k);
+ }
+
@Override
public String getWriteableName() {
return NAME;
@@ -113,18 +133,18 @@ public class MeanReciprocalRank implements EvaluationMetric {
}
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
+ private static final ParseField K_FIELD = new ParseField("k");
private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("reciprocal_rank",
args -> {
Integer optionalThreshold = (Integer) args[0];
- if (optionalThreshold == null) {
- return new MeanReciprocalRank();
- } else {
- return new MeanReciprocalRank(optionalThreshold);
- }
+ Integer optionalK = (Integer) args[1];
+ return new MeanReciprocalRank(optionalThreshold == null ? DEFAULT_RATING_THRESHOLD : optionalThreshold,
+ optionalK == null ? DEFAULT_K : optionalK);
});
static {
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
+ PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
public static MeanReciprocalRank fromXContent(XContentParser parser) {
@@ -136,6 +156,7 @@ public class MeanReciprocalRank implements EvaluationMetric {
builder.startObject();
builder.startObject(NAME);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
+ builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
@@ -150,12 +171,13 @@ public class MeanReciprocalRank implements EvaluationMetric {
return false;
}
MeanReciprocalRank other = (MeanReciprocalRank) obj;
- return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
+ return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
+ && Objects.equals(k, other.k);
}
@Override
public final int hashCode() {
- return Objects.hash(relevantRatingThreshhold);
+ return Objects.hash(relevantRatingThreshhold, k);
}
static class Breakdown implements MetricDetails {
diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java
index 4beeeea2b40..63bdcb7307d 100644
--- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java
+++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java
@@ -96,7 +96,7 @@ public class PrecisionAtK implements EvaluationMetric {
Integer k = (Integer) args[2];
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
ignoreUnlabeled == null ? false : ignoreUnlabeled,
- k == null ? 10 : k);
+ k == null ? DEFAULT_K : k);
});
static {
@@ -111,6 +111,10 @@ public class PrecisionAtK implements EvaluationMetric {
k = in.readVInt();
}
+ int getK() {
+ return this.k;
+ }
+
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRatingThreshhold);
diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java
index f3c38b7ae64..00f8a3018d4 100644
--- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java
+++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java
@@ -42,13 +42,18 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
public class DiscountedCumulativeGainTests extends ESTestCase {
+ static final double EXPECTED_DCG = 13.84826362927298;
+ static final double EXPECTED_IDCG = 14.595390756454922;
+ static final double EXPECTED_NDCG = EXPECTED_DCG / EXPECTED_IDCG;
+ private static final double DELTA = 10E-16;
+
/**
* Assuming the docs are ranked in the following order:
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
- * 1 | 3 | 7.0 | 1.0 | 7.0 2 |
- * 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
+ * 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 |
+ * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
@@ -66,7 +71,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null));
}
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
- assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
+ assertEquals(EXPECTED_DCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
/**
* Check with normalization: to get the maximal possible dcg, sort documents by
@@ -83,8 +88,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
*
* idcg = 14.595390756454922 (sum of last column)
*/
- dcg = new DiscountedCumulativeGain(true, null);
- assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
+ dcg = new DiscountedCumulativeGain(true, null, 10);
+ assertEquals(EXPECTED_NDCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
}
/**
@@ -117,7 +122,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
}
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
EvalQueryQuality result = dcg.evaluate("id", hits, rated);
- assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001);
+ assertEquals(12.779642067948913, result.getQualityLevel(), DELTA);
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
/**
@@ -135,8 +140,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
*
* idcg = 13.347184833073591 (sum of last column)
*/
- dcg = new DiscountedCumulativeGain(true, null);
- assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
+ dcg = new DiscountedCumulativeGain(true, null, 10);
+ assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
}
/**
@@ -174,7 +179,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
}
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
- assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001);
+ assertEquals(12.392789260714371, result.getQualityLevel(), DELTA);
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
/**
@@ -193,16 +198,27 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
*
* idcg = 13.347184833073591 (sum of last column)
*/
- dcg = new DiscountedCumulativeGain(true, null);
- assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
+ dcg = new DiscountedCumulativeGain(true, null, 10);
+ assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), DELTA);
}
public void testParseFromXContent() throws IOException {
- String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }";
+ assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true, \"k\" : 15 }", 2, true, 15);
+ assertParsedCorrect("{ \"normalize\": false, \"k\" : 15 }", null, false, 15);
+ assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"k\" : 15 }", 2, false, 15);
+ assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true }", 2, true, 10);
+ assertParsedCorrect("{ \"normalize\": true }", null, true, 10);
+ assertParsedCorrect("{ \"k\": 23 }", null, false, 23);
+ assertParsedCorrect("{ \"unknown_doc_rating\": 2 }", 2, false, 10);
+ }
+
+ private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, boolean expectedNormalize, int expectedK)
+ throws IOException {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
- assertEquals(2, dcgAt.getUnknownDocRating().intValue());
- assertEquals(true, dcgAt.getNormalize());
+ assertEquals(expectedUnknownDocRating, dcgAt.getUnknownDocRating());
+ assertEquals(expectedNormalize, dcgAt.getNormalize());
+ assertEquals(expectedK, dcgAt.getK());
}
}
@@ -210,7 +226,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
boolean normalize = randomBoolean();
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
- return new DiscountedCumulativeGain(normalize, unknownDocRating);
+ return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
}
public void testXContentRoundtrip() throws IOException {
@@ -238,16 +254,22 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
public void testEqualsAndHash() throws IOException {
checkEqualsAndHashCode(createTestItem(), original -> {
- return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating());
+ return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), original.getK());
}, DiscountedCumulativeGainTests::mutateTestItem);
}
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
- if (randomBoolean()) {
- return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating());
- } else {
+ switch (randomIntBetween(0, 2)) {
+ case 0:
+ return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating(), original.getK());
+ case 1:
return new DiscountedCumulativeGain(original.getNormalize(),
- randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)));
+ randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK());
+ case 2:
+ return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(),
+ randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10)));
+ default:
+ throw new IllegalArgumentException("mutation variant not allowed");
}
}
}
diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java
index 3a1b4939758..42f7e32671f 100644
--- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java
+++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/MeanReciprocalRankTests.java
@@ -48,12 +48,28 @@ public class MeanReciprocalRankTests extends ESTestCase {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(1, mrr.getRelevantRatingThreshold());
+ assertEquals(10, mrr.getK());
}
xContent = "{ \"relevant_rating_threshold\": 2 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(2, mrr.getRelevantRatingThreshold());
+ assertEquals(10, mrr.getK());
+ }
+
+ xContent = "{ \"relevant_rating_threshold\": 2, \"k\" : 15 }";
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
+ MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
+ assertEquals(2, mrr.getRelevantRatingThreshold());
+ assertEquals(15, mrr.getK());
+ }
+
+ xContent = "{ \"k\" : 15 }";
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
+ MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
+ assertEquals(1, mrr.getRelevantRatingThreshold());
+ assertEquals(15, mrr.getK());
}
}
@@ -116,7 +132,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
rated.add(new RatedDocument("test", "4", 4));
SearchHit[] hits = createSearchHits(0, 5, "test");
- MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2);
+ MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10);
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
@@ -167,7 +183,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
}
static MeanReciprocalRank createTestItem() {
- return new MeanReciprocalRank(randomIntBetween(0, 20));
+ return new MeanReciprocalRank(randomIntBetween(0, 20), randomIntBetween(1, 20));
}
public void testSerialization() throws IOException {
@@ -184,14 +200,22 @@ public class MeanReciprocalRankTests extends ESTestCase {
}
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
- return new MeanReciprocalRank(testItem.getRelevantRatingThreshold());
+ return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK());
}
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
- return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)));
+ if (randomBoolean()) {
+ return new MeanReciprocalRank(testItem.getRelevantRatingThreshold() + 1, testItem.getK());
+ } else {
+ return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK() + 1);
+ }
}
public void testInvalidRelevantThreshold() {
- expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1));
+ expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1, 1));
+ }
+
+ public void testInvalidK() {
+ expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(1, -1));
}
}
diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java
index 2ab3e2d9a57..1c10da61fa1 100644
--- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java
+++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestIT.java
@@ -20,18 +20,17 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@@ -64,13 +63,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
refresh();
}
+ /**
+ * Test cases retrieves all six documents indexed above. The first part checks the Prec@10 calculation where
+ * all unlabeled docs are treated as "unrelevant". We average Prec@ metric across two search use cases, the
+ * first one that labels 4 out of the 6 documents as relevant, the second one with only one relevant document.
+ */
public void testPrecisionAtRequest() {
- List indices = Arrays.asList(new String[] { "test" });
-
List specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
- testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME);
+ testQuery.sort("_id");
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
createRelevant("2", "3", "4", "5"), testQuery);
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
@@ -79,12 +81,11 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"),
testQuery);
berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
-
specifications.add(berlinRequest);
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
- task.addIndices(indices);
+ task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
RankEvalAction.INSTANCE, new RankEvalRequest());
@@ -92,6 +93,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
.actionGet();
+ // the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
+ // second is 1/6, divided by 2 to get the average
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
Set> entrySet = response.getPartialResults().entrySet();
@@ -129,14 +132,96 @@ public class RankEvalRequestIT extends ESIntegTestCase {
// test that a different window size k affects the result
metric = new PrecisionAtK(1, false, 3);
task = new RankEvalSpec(specifications, metric);
- task.addIndices(indices);
+ task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
+ // if we look only at top 3 documente, the expected P@3 for the first query is
+ // 2/3 and the expected Prec@ for the second is 1/3, divided by 2 to get the average
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
- assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE);
+ assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
+ }
+
+ /**
+ * This test assumes we are using the same ratings as in {@link DiscountedCumulativeGainTests#testDCGAt()}.
+ * See details in that test case for how the expected values are calculated
+ */
+ public void testDCGRequest() {
+ SearchSourceBuilder testQuery = new SearchSourceBuilder();
+ testQuery.query(new MatchAllQueryBuilder());
+ testQuery.sort("_id");
+
+ List specifications = new ArrayList<>();
+ List ratedDocs = Arrays.asList(
+ new RatedDocument("test", "1", 3),
+ new RatedDocument("test", "2", 2),
+ new RatedDocument("test", "3", 3),
+ new RatedDocument("test", "4", 0),
+ new RatedDocument("test", "5", 1),
+ new RatedDocument("test", "6", 2));
+ specifications.add(new RatedRequest("amsterdam_query", ratedDocs, testQuery));
+
+ DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10);
+ RankEvalSpec task = new RankEvalSpec(specifications, metric);
+ task.addIndices(Collections.singletonList("test"));
+
+ RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
+ builder.setRankEvalSpec(task);
+
+ RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
+ assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), Double.MIN_VALUE);
+
+ // test that a different window size k affects the result
+ metric = new DiscountedCumulativeGain(false, null, 3);
+ task = new RankEvalSpec(specifications, metric);
+ task.addIndices(Collections.singletonList("test"));
+
+ builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
+ builder.setRankEvalSpec(task);
+
+ response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
+ assertEquals(12.392789260714371, response.getEvaluationResult(), Double.MIN_VALUE);
+ }
+
+ public void testMRRRequest() {
+ SearchSourceBuilder testQuery = new SearchSourceBuilder();
+ testQuery.query(new MatchAllQueryBuilder());
+ testQuery.sort("_id");
+
+ List specifications = new ArrayList<>();
+ specifications.add(new RatedRequest("amsterdam_query", createRelevant("5"), testQuery));
+ specifications.add(new RatedRequest("berlin_query", createRelevant("1"), testQuery));
+
+ MeanReciprocalRank metric = new MeanReciprocalRank(1, 10);
+ RankEvalSpec task = new RankEvalSpec(specifications, metric);
+ task.addIndices(Collections.singletonList("test"));
+
+ RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
+ builder.setRankEvalSpec(task);
+
+ RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
+ // the expected reciprocal rank for the amsterdam_query is 1/5
+ // the expected reciprocal rank for the berlin_query is 1/1
+ // dividing by 2 to get the average
+ double expectedMRR = (1.0 / 1.0 + 1.0 / 5.0) / 2.0;
+ assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
+
+ // test that a different window size k affects the result
+ metric = new MeanReciprocalRank(1, 3);
+ task = new RankEvalSpec(specifications, metric);
+ task.addIndices(Collections.singletonList("test"));
+
+ builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
+ builder.setRankEvalSpec(task);
+
+ response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
+ // limiting to top 3 results, the amsterdam_query has no relevant document in it
+ // the reciprocal rank for the berlin_query is 1/1
+ // dividing by 2 to get the average
+ expectedMRR = (1.0/ 1.0) / 2.0;
+ assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
}
/**
@@ -162,16 +247,13 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
task.addIndices(indices);
- try (Client client = client()) {
- RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest());
- builder.setRankEvalSpec(task);
+ RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
+ builder.setRankEvalSpec(task);
- RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
- assertEquals(1, response.getFailures().size());
- ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
- assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
- rootCauses[0].getCause().toString());
- }
+ RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
+ assertEquals(1, response.getFailures().size());
+ ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
+ assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", rootCauses[0].getCause().toString());
}
private static List createRelevant(String... docs) {