Add search window parameter k to MRR and DCG metric (#27595)

This commit is contained in:
Christoph Büscher 2017-12-04 10:54:03 +01:00 committed by GitHub
parent 35688f6441
commit 72d0de4197
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 248 additions and 69 deletions

View File

@ -33,21 +33,29 @@ import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
/** /**
* Metric implementing Discounted Cumulative Gain (https://en.wikipedia.org/wiki/Discounted_cumulative_gain).<br> * Metric implementing Discounted Cumulative Gain.
* The `normalize` parameter can be set to calculate the normalized NDCG (set to <tt>false</tt> by default).<br> * The `normalize` parameter can be set to calculate the normalized NDCG (set to <tt>false</tt> by default).<br>
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents. * The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents.
* @see <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Discounted_Cumulative_Gain">Discounted Cumulative Gain</a><br>
*/ */
public class DiscountedCumulativeGain implements EvaluationMetric { public class DiscountedCumulativeGain implements EvaluationMetric {
/** If set to true, the dcg will be normalized (ndcg) */ /** If set to true, the dcg will be normalized (ndcg) */
private final boolean normalize; private final boolean normalize;
/** the default search window size */
private static final int DEFAULT_K = 10;
/** the search window size */
private final int k;
/** /**
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request * Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
*/ */
@ -57,7 +65,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
private static final double LOG2 = Math.log(2.0); private static final double LOG2 = Math.log(2.0);
public DiscountedCumulativeGain() { public DiscountedCumulativeGain() {
this(false, null); this(false, null, DEFAULT_K);
} }
/** /**
@ -65,23 +73,27 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
* If set to true, dcg will be normalized (ndcg) See * If set to true, dcg will be normalized (ndcg) See
* https://en.wikipedia.org/wiki/Discounted_cumulative_gain * https://en.wikipedia.org/wiki/Discounted_cumulative_gain
* @param unknownDocRating * @param unknownDocRating
* the rating for docs the user hasn't supplied an explicit * the rating for documents the user hasn't supplied an explicit
* rating for * rating for
* @param k the search window size all request use.
*/ */
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating) { public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating, int k) {
this.normalize = normalize; this.normalize = normalize;
this.unknownDocRating = unknownDocRating; this.unknownDocRating = unknownDocRating;
this.k = k;
} }
DiscountedCumulativeGain(StreamInput in) throws IOException { DiscountedCumulativeGain(StreamInput in) throws IOException {
normalize = in.readBoolean(); normalize = in.readBoolean();
unknownDocRating = in.readOptionalVInt(); unknownDocRating = in.readOptionalVInt();
k = in.readVInt();
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(normalize); out.writeBoolean(normalize);
out.writeOptionalVInt(unknownDocRating); out.writeOptionalVInt(unknownDocRating);
out.writeVInt(k);
} }
@Override @Override
@ -89,13 +101,14 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return NAME; return NAME;
} }
/** boolean getNormalize() {
* check whether this metric computes only dcg or "normalized" ndcg
*/
public boolean getNormalize() {
return this.normalize; return this.normalize;
} }
int getK() {
return this.k;
}
/** /**
* get the rating used for unrated documents * get the rating used for unrated documents
*/ */
@ -103,6 +116,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return this.unknownDocRating; return this.unknownDocRating;
} }
@Override
public Optional<Integer> forcedSearchSize() {
return Optional.of(k);
}
@Override @Override
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
List<RatedDocument> ratedDocs) { List<RatedDocument> ratedDocs) {
@ -142,17 +161,21 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return dcg; return dcg;
} }
private static final ParseField K_FIELD = new ParseField("k");
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at", private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
args -> { args -> {
Boolean normalized = (Boolean) args[0]; Boolean normalized = (Boolean) args[0];
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]); Integer optK = (Integer) args[2];
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1],
optK == null ? DEFAULT_K : optK);
}); });
static { static {
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD); PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD); PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
} }
public static DiscountedCumulativeGain fromXContent(XContentParser parser) { public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
@ -167,6 +190,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
if (unknownDocRating != null) { if (unknownDocRating != null) {
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating); builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
} }
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject(); builder.endObject();
builder.endObject(); builder.endObject();
return builder; return builder;
@ -182,11 +206,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
} }
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj; DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
return Objects.equals(normalize, other.normalize) return Objects.equals(normalize, other.normalize)
&& Objects.equals(unknownDocRating, other.unknownDocRating); && Objects.equals(unknownDocRating, other.unknownDocRating)
&& Objects.equals(k, other.k);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(normalize, unknownDocRating); return Objects.hash(normalize, unknownDocRating, k);
} }
} }

View File

@ -42,37 +42,57 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
*/ */
public class MeanReciprocalRank implements EvaluationMetric { public class MeanReciprocalRank implements EvaluationMetric {
private static final int DEFAULT_RATING_THRESHOLD = 1;
public static final String NAME = "mean_reciprocal_rank"; public static final String NAME = "mean_reciprocal_rank";
/** ratings equal or above this value will be considered relevant. */ private static final int DEFAULT_RATING_THRESHOLD = 1;
private static final int DEFAULT_K = 10;
/** the search window size */
private final int k;
/** ratings equal or above this value will be considered relevant */
private final int relevantRatingThreshhold; private final int relevantRatingThreshhold;
public MeanReciprocalRank() { public MeanReciprocalRank() {
this(DEFAULT_RATING_THRESHOLD); this(DEFAULT_RATING_THRESHOLD, DEFAULT_K);
} }
MeanReciprocalRank(StreamInput in) throws IOException { MeanReciprocalRank(StreamInput in) throws IOException {
this.relevantRatingThreshhold = in.readVInt(); this.relevantRatingThreshhold = in.readVInt();
this.k = in.readVInt();
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRatingThreshhold); out.writeVInt(this.relevantRatingThreshhold);
out.writeVInt(this.k);
} }
/** /**
* Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).<br> * Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).<br>
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevalnt". Defaults to 1. * @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevant". Defaults to 1.
* @param k the search window size all request use.
*/ */
public MeanReciprocalRank(int relevantRatingThreshold) { public MeanReciprocalRank(int relevantRatingThreshold, int k) {
if (relevantRatingThreshold < 0) { if (relevantRatingThreshold < 0) {
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer."); throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
} }
if (k <= 0) {
throw new IllegalArgumentException("Window size k must be positive.");
}
this.k = k;
this.relevantRatingThreshhold = relevantRatingThreshold; this.relevantRatingThreshhold = relevantRatingThreshold;
} }
int getK() {
return this.k;
}
@Override
public Optional<Integer> forcedSearchSize() {
return Optional.of(k);
}
@Override @Override
public String getWriteableName() { public String getWriteableName() {
return NAME; return NAME;
@ -113,18 +133,18 @@ public class MeanReciprocalRank implements EvaluationMetric {
} }
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField K_FIELD = new ParseField("k");
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank", private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
args -> { args -> {
Integer optionalThreshold = (Integer) args[0]; Integer optionalThreshold = (Integer) args[0];
if (optionalThreshold == null) { Integer optionalK = (Integer) args[1];
return new MeanReciprocalRank(); return new MeanReciprocalRank(optionalThreshold == null ? DEFAULT_RATING_THRESHOLD : optionalThreshold,
} else { optionalK == null ? DEFAULT_K : optionalK);
return new MeanReciprocalRank(optionalThreshold);
}
}); });
static { static {
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD); PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
} }
public static MeanReciprocalRank fromXContent(XContentParser parser) { public static MeanReciprocalRank fromXContent(XContentParser parser) {
@ -136,6 +156,7 @@ public class MeanReciprocalRank implements EvaluationMetric {
builder.startObject(); builder.startObject();
builder.startObject(NAME); builder.startObject(NAME);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold); builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject(); builder.endObject();
builder.endObject(); builder.endObject();
return builder; return builder;
@ -150,12 +171,13 @@ public class MeanReciprocalRank implements EvaluationMetric {
return false; return false;
} }
MeanReciprocalRank other = (MeanReciprocalRank) obj; MeanReciprocalRank other = (MeanReciprocalRank) obj;
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold); return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
&& Objects.equals(k, other.k);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(relevantRatingThreshhold); return Objects.hash(relevantRatingThreshhold, k);
} }
static class Breakdown implements MetricDetails { static class Breakdown implements MetricDetails {

View File

@ -96,7 +96,7 @@ public class PrecisionAtK implements EvaluationMetric {
Integer k = (Integer) args[2]; Integer k = (Integer) args[2];
return new PrecisionAtK(threshHold == null ? 1 : threshHold, return new PrecisionAtK(threshHold == null ? 1 : threshHold,
ignoreUnlabeled == null ? false : ignoreUnlabeled, ignoreUnlabeled == null ? false : ignoreUnlabeled,
k == null ? 10 : k); k == null ? DEFAULT_K : k);
}); });
static { static {
@ -111,6 +111,10 @@ public class PrecisionAtK implements EvaluationMetric {
k = in.readVInt(); k = in.readVInt();
} }
int getK() {
return this.k;
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRatingThreshhold); out.writeVInt(relevantRatingThreshhold);

View File

@ -42,13 +42,18 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
public class DiscountedCumulativeGainTests extends ESTestCase { public class DiscountedCumulativeGainTests extends ESTestCase {
static final double EXPECTED_DCG = 13.84826362927298;
static final double EXPECTED_IDCG = 14.595390756454922;
static final double EXPECTED_NDCG = EXPECTED_DCG / EXPECTED_IDCG;
private static final double DELTA = 10E-16;
/** /**
* Assuming the docs are ranked in the following order: * Assuming the docs are ranked in the following order:
* *
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* ------------------------------------------------------------------------------------------- * -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 |  * 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 | 
* 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5 * 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0 * 4 | 0 | 0.0 | 2.321928094887362 | 0.0
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
@ -66,7 +71,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null)); hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null));
} }
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); assertEquals(EXPECTED_DCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
/** /**
* Check with normalization: to get the maximal possible dcg, sort documents by * Check with normalization: to get the maximal possible dcg, sort documents by
@ -83,8 +88,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* *
* idcg = 14.595390756454922 (sum of last column) * idcg = 14.595390756454922 (sum of last column)
*/ */
dcg = new DiscountedCumulativeGain(true, null); dcg = new DiscountedCumulativeGain(true, null, 10);
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); assertEquals(EXPECTED_NDCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
} }
/** /**
@ -117,7 +122,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
} }
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
EvalQueryQuality result = dcg.evaluate("id", hits, rated); EvalQueryQuality result = dcg.evaluate("id", hits, rated);
assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001); assertEquals(12.779642067948913, result.getQualityLevel(), DELTA);
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size()); assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
/** /**
@ -135,8 +140,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* *
* idcg = 13.347184833073591 (sum of last column) * idcg = 13.347184833073591 (sum of last column)
*/ */
dcg = new DiscountedCumulativeGain(true, null); dcg = new DiscountedCumulativeGain(true, null, 10);
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
} }
/** /**
@ -174,7 +179,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
} }
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs); EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001); assertEquals(12.392789260714371, result.getQualityLevel(), DELTA);
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size()); assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
/** /**
@ -193,16 +198,27 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* *
* idcg = 13.347184833073591 (sum of last column) * idcg = 13.347184833073591 (sum of last column)
*/ */
dcg = new DiscountedCumulativeGain(true, null); dcg = new DiscountedCumulativeGain(true, null, 10);
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001); assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), DELTA);
} }
public void testParseFromXContent() throws IOException { public void testParseFromXContent() throws IOException {
String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }"; assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true, \"k\" : 15 }", 2, true, 15);
assertParsedCorrect("{ \"normalize\": false, \"k\" : 15 }", null, false, 15);
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"k\" : 15 }", 2, false, 15);
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true }", 2, true, 10);
assertParsedCorrect("{ \"normalize\": true }", null, true, 10);
assertParsedCorrect("{ \"k\": 23 }", null, false, 23);
assertParsedCorrect("{ \"unknown_doc_rating\": 2 }", 2, false, 10);
}
private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, boolean expectedNormalize, int expectedK)
throws IOException {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser); DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
assertEquals(2, dcgAt.getUnknownDocRating().intValue()); assertEquals(expectedUnknownDocRating, dcgAt.getUnknownDocRating());
assertEquals(true, dcgAt.getNormalize()); assertEquals(expectedNormalize, dcgAt.getNormalize());
assertEquals(expectedK, dcgAt.getK());
} }
} }
@ -210,7 +226,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
boolean normalize = randomBoolean(); boolean normalize = randomBoolean();
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000)); Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
return new DiscountedCumulativeGain(normalize, unknownDocRating); return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
} }
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
@ -238,16 +254,22 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
public void testEqualsAndHash() throws IOException { public void testEqualsAndHash() throws IOException {
checkEqualsAndHashCode(createTestItem(), original -> { checkEqualsAndHashCode(createTestItem(), original -> {
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating()); return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), original.getK());
}, DiscountedCumulativeGainTests::mutateTestItem); }, DiscountedCumulativeGainTests::mutateTestItem);
} }
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) { private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
if (randomBoolean()) { switch (randomIntBetween(0, 2)) {
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating()); case 0:
} else { return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating(), original.getK());
case 1:
return new DiscountedCumulativeGain(original.getNormalize(), return new DiscountedCumulativeGain(original.getNormalize(),
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10))); randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK());
case 2:
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(),
randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10)));
default:
throw new IllegalArgumentException("mutation variant not allowed");
} }
} }
} }

View File

@ -48,12 +48,28 @@ public class MeanReciprocalRankTests extends ESTestCase {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(1, mrr.getRelevantRatingThreshold()); assertEquals(1, mrr.getRelevantRatingThreshold());
assertEquals(10, mrr.getK());
} }
xContent = "{ \"relevant_rating_threshold\": 2 }"; xContent = "{ \"relevant_rating_threshold\": 2 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser); MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(2, mrr.getRelevantRatingThreshold()); assertEquals(2, mrr.getRelevantRatingThreshold());
assertEquals(10, mrr.getK());
}
xContent = "{ \"relevant_rating_threshold\": 2, \"k\" : 15 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(2, mrr.getRelevantRatingThreshold());
assertEquals(15, mrr.getK());
}
xContent = "{ \"k\" : 15 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(1, mrr.getRelevantRatingThreshold());
assertEquals(15, mrr.getK());
} }
} }
@ -116,7 +132,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
rated.add(new RatedDocument("test", "4", 4)); rated.add(new RatedDocument("test", "4", 4));
SearchHit[] hits = createSearchHits(0, 5, "test"); SearchHit[] hits = createSearchHits(0, 5, "test");
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2); MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10);
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001); assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank()); assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
@ -167,7 +183,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
} }
static MeanReciprocalRank createTestItem() { static MeanReciprocalRank createTestItem() {
return new MeanReciprocalRank(randomIntBetween(0, 20)); return new MeanReciprocalRank(randomIntBetween(0, 20), randomIntBetween(1, 20));
} }
public void testSerialization() throws IOException { public void testSerialization() throws IOException {
@ -184,14 +200,22 @@ public class MeanReciprocalRankTests extends ESTestCase {
} }
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) { private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold()); return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK());
} }
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) { private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10))); if (randomBoolean()) {
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold() + 1, testItem.getK());
} else {
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK() + 1);
}
} }
public void testInvalidRelevantThreshold() { public void testInvalidRelevantThreshold() {
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1)); expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1, 1));
}
public void testInvalidK() {
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(1, -1));
} }
} }

View File

@ -20,18 +20,17 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase;
import org.junit.Before; import org.junit.Before;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
@ -64,13 +63,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
refresh(); refresh();
} }
/**
* Test cases retrieves all six documents indexed above. The first part checks the Prec@10 calculation where
* all unlabeled docs are treated as "unrelevant". We average Prec@ metric across two search use cases, the
* first one that labels 4 out of the 6 documents as relevant, the second one with only one relevant document.
*/
public void testPrecisionAtRequest() { public void testPrecisionAtRequest() {
List<String> indices = Arrays.asList(new String[] { "test" });
List<RatedRequest> specifications = new ArrayList<>(); List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder(); SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder()); testQuery.query(new MatchAllQueryBuilder());
testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME); testQuery.sort("_id");
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
createRelevant("2", "3", "4", "5"), testQuery); createRelevant("2", "3", "4", "5"), testQuery);
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" })); amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
@ -79,12 +81,11 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"), RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"),
testQuery); testQuery);
berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" })); berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
specifications.add(berlinRequest); specifications.add(berlinRequest);
PrecisionAtK metric = new PrecisionAtK(1, false, 10); PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric); RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(indices); task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
RankEvalAction.INSTANCE, new RankEvalRequest()); RankEvalAction.INSTANCE, new RankEvalRequest());
@ -92,6 +93,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()) RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
.actionGet(); .actionGet();
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
// second is 1/6, divided by 2 to get the average
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0; double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE); assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet(); Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
@ -129,14 +132,96 @@ public class RankEvalRequestIT extends ESIntegTestCase {
// test that a different window size k affects the result // test that a different window size k affects the result
metric = new PrecisionAtK(1, false, 3); metric = new PrecisionAtK(1, false, 3);
task = new RankEvalSpec(specifications, metric); task = new RankEvalSpec(specifications, metric);
task.addIndices(indices); task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task); builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// if we look only at top 3 documente, the expected P@3 for the first query is
// 2/3 and the expected Prec@ for the second is 1/3, divided by 2 to get the average
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0; expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE); assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
}
/**
* This test assumes we are using the same ratings as in {@link DiscountedCumulativeGainTests#testDCGAt()}.
* See details in that test case for how the expected values are calculated
*/
public void testDCGRequest() {
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
testQuery.sort("_id");
List<RatedRequest> specifications = new ArrayList<>();
List<RatedDocument> ratedDocs = Arrays.asList(
new RatedDocument("test", "1", 3),
new RatedDocument("test", "2", 2),
new RatedDocument("test", "3", 3),
new RatedDocument("test", "4", 0),
new RatedDocument("test", "5", 1),
new RatedDocument("test", "6", 2));
specifications.add(new RatedRequest("amsterdam_query", ratedDocs, testQuery));
DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), Double.MIN_VALUE);
// test that a different window size k affects the result
metric = new DiscountedCumulativeGain(false, null, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(12.392789260714371, response.getEvaluationResult(), Double.MIN_VALUE);
}
public void testMRRRequest() {
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
testQuery.sort("_id");
List<RatedRequest> specifications = new ArrayList<>();
specifications.add(new RatedRequest("amsterdam_query", createRelevant("5"), testQuery));
specifications.add(new RatedRequest("berlin_query", createRelevant("1"), testQuery));
MeanReciprocalRank metric = new MeanReciprocalRank(1, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// the expected reciprocal rank for the amsterdam_query is 1/5
// the expected reciprocal rank for the berlin_query is 1/1
// dividing by 2 to get the average
double expectedMRR = (1.0 / 1.0 + 1.0 / 5.0) / 2.0;
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
// test that a different window size k affects the result
metric = new MeanReciprocalRank(1, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// limiting to top 3 results, the amsterdam_query has no relevant document in it
// the reciprocal rank for the berlin_query is 1/1
// dividing by 2 to get the average
expectedMRR = (1.0/ 1.0) / 2.0;
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
} }
/** /**
@ -162,16 +247,13 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK()); RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
task.addIndices(indices); task.addIndices(indices);
try (Client client = client()) { RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest()); builder.setRankEvalSpec(task);
builder.setRankEvalSpec(task);
RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(1, response.getFailures().size()); assertEquals(1, response.getFailures().size());
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query")); ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", rootCauses[0].getCause().toString());
rootCauses[0].getCause().toString());
}
} }
private static List<RatedDocument> createRelevant(String... docs) { private static List<RatedDocument> createRelevant(String... docs) {