Add search window parameter k to MRR and DCG metric (#27595)
This commit is contained in:
parent
35688f6441
commit
72d0de4197
|
@ -33,21 +33,29 @@ import java.util.Collections;
|
|||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
||||
|
||||
/**
|
||||
* Metric implementing Discounted Cumulative Gain (https://en.wikipedia.org/wiki/Discounted_cumulative_gain).<br>
|
||||
* Metric implementing Discounted Cumulative Gain.
|
||||
* The `normalize` parameter can be set to calculate the normalized NDCG (set to <tt>false</tt> by default).<br>
|
||||
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents.
|
||||
* @see <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Discounted_Cumulative_Gain">Discounted Cumulative Gain</a><br>
|
||||
*/
|
||||
public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||
|
||||
/** If set to true, the dcg will be normalized (ndcg) */
|
||||
private final boolean normalize;
|
||||
|
||||
/** the default search window size */
|
||||
private static final int DEFAULT_K = 10;
|
||||
|
||||
/** the search window size */
|
||||
private final int k;
|
||||
|
||||
/**
|
||||
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
|
||||
*/
|
||||
|
@ -57,7 +65,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
private static final double LOG2 = Math.log(2.0);
|
||||
|
||||
public DiscountedCumulativeGain() {
|
||||
this(false, null);
|
||||
this(false, null, DEFAULT_K);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -65,23 +73,27 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
* If set to true, dcg will be normalized (ndcg) See
|
||||
* https://en.wikipedia.org/wiki/Discounted_cumulative_gain
|
||||
* @param unknownDocRating
|
||||
* the rating for docs the user hasn't supplied an explicit
|
||||
* the rating for documents the user hasn't supplied an explicit
|
||||
* rating for
|
||||
* @param k the search window size all request use.
|
||||
*/
|
||||
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating) {
|
||||
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating, int k) {
|
||||
this.normalize = normalize;
|
||||
this.unknownDocRating = unknownDocRating;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
DiscountedCumulativeGain(StreamInput in) throws IOException {
|
||||
normalize = in.readBoolean();
|
||||
unknownDocRating = in.readOptionalVInt();
|
||||
k = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeBoolean(normalize);
|
||||
out.writeOptionalVInt(unknownDocRating);
|
||||
out.writeVInt(k);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -89,13 +101,14 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* check whether this metric computes only dcg or "normalized" ndcg
|
||||
*/
|
||||
public boolean getNormalize() {
|
||||
boolean getNormalize() {
|
||||
return this.normalize;
|
||||
}
|
||||
|
||||
int getK() {
|
||||
return this.k;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the rating used for unrated documents
|
||||
*/
|
||||
|
@ -103,6 +116,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
return this.unknownDocRating;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Optional<Integer> forcedSearchSize() {
|
||||
return Optional.of(k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
|
||||
List<RatedDocument> ratedDocs) {
|
||||
|
@ -142,17 +161,21 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
return dcg;
|
||||
}
|
||||
|
||||
private static final ParseField K_FIELD = new ParseField("k");
|
||||
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
|
||||
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
|
||||
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
|
||||
args -> {
|
||||
Boolean normalized = (Boolean) args[0];
|
||||
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
|
||||
Integer optK = (Integer) args[2];
|
||||
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1],
|
||||
optK == null ? DEFAULT_K : optK);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||
}
|
||||
|
||||
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
|
||||
|
@ -167,6 +190,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
if (unknownDocRating != null) {
|
||||
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
|
||||
}
|
||||
builder.field(K_FIELD.getPreferredName(), this.k);
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -182,11 +206,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
|||
}
|
||||
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
|
||||
return Objects.equals(normalize, other.normalize)
|
||||
&& Objects.equals(unknownDocRating, other.unknownDocRating);
|
||||
&& Objects.equals(unknownDocRating, other.unknownDocRating)
|
||||
&& Objects.equals(k, other.k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(normalize, unknownDocRating);
|
||||
return Objects.hash(normalize, unknownDocRating, k);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,37 +42,57 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
|
|||
*/
|
||||
public class MeanReciprocalRank implements EvaluationMetric {
|
||||
|
||||
private static final int DEFAULT_RATING_THRESHOLD = 1;
|
||||
|
||||
public static final String NAME = "mean_reciprocal_rank";
|
||||
|
||||
/** ratings equal or above this value will be considered relevant. */
|
||||
private static final int DEFAULT_RATING_THRESHOLD = 1;
|
||||
private static final int DEFAULT_K = 10;
|
||||
|
||||
/** the search window size */
|
||||
private final int k;
|
||||
|
||||
/** ratings equal or above this value will be considered relevant */
|
||||
private final int relevantRatingThreshhold;
|
||||
|
||||
public MeanReciprocalRank() {
|
||||
this(DEFAULT_RATING_THRESHOLD);
|
||||
this(DEFAULT_RATING_THRESHOLD, DEFAULT_K);
|
||||
}
|
||||
|
||||
MeanReciprocalRank(StreamInput in) throws IOException {
|
||||
this.relevantRatingThreshhold = in.readVInt();
|
||||
this.k = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(relevantRatingThreshhold);
|
||||
out.writeVInt(this.relevantRatingThreshhold);
|
||||
out.writeVInt(this.k);
|
||||
}
|
||||
|
||||
/**
|
||||
* Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).<br>
|
||||
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevalnt". Defaults to 1.
|
||||
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevant". Defaults to 1.
|
||||
* @param k the search window size all request use.
|
||||
*/
|
||||
public MeanReciprocalRank(int relevantRatingThreshold) {
|
||||
public MeanReciprocalRank(int relevantRatingThreshold, int k) {
|
||||
if (relevantRatingThreshold < 0) {
|
||||
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
if (k <= 0) {
|
||||
throw new IllegalArgumentException("Window size k must be positive.");
|
||||
}
|
||||
this.k = k;
|
||||
this.relevantRatingThreshhold = relevantRatingThreshold;
|
||||
}
|
||||
|
||||
int getK() {
|
||||
return this.k;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<Integer> forcedSearchSize() {
|
||||
return Optional.of(k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
|
@ -113,18 +133,18 @@ public class MeanReciprocalRank implements EvaluationMetric {
|
|||
}
|
||||
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ParseField K_FIELD = new ParseField("k");
|
||||
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
|
||||
args -> {
|
||||
Integer optionalThreshold = (Integer) args[0];
|
||||
if (optionalThreshold == null) {
|
||||
return new MeanReciprocalRank();
|
||||
} else {
|
||||
return new MeanReciprocalRank(optionalThreshold);
|
||||
}
|
||||
Integer optionalK = (Integer) args[1];
|
||||
return new MeanReciprocalRank(optionalThreshold == null ? DEFAULT_RATING_THRESHOLD : optionalThreshold,
|
||||
optionalK == null ? DEFAULT_K : optionalK);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||
}
|
||||
|
||||
public static MeanReciprocalRank fromXContent(XContentParser parser) {
|
||||
|
@ -136,6 +156,7 @@ public class MeanReciprocalRank implements EvaluationMetric {
|
|||
builder.startObject();
|
||||
builder.startObject(NAME);
|
||||
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
|
||||
builder.field(K_FIELD.getPreferredName(), this.k);
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -150,12 +171,13 @@ public class MeanReciprocalRank implements EvaluationMetric {
|
|||
return false;
|
||||
}
|
||||
MeanReciprocalRank other = (MeanReciprocalRank) obj;
|
||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
|
||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
|
||||
&& Objects.equals(k, other.k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(relevantRatingThreshhold);
|
||||
return Objects.hash(relevantRatingThreshhold, k);
|
||||
}
|
||||
|
||||
static class Breakdown implements MetricDetails {
|
||||
|
|
|
@ -96,7 +96,7 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
Integer k = (Integer) args[2];
|
||||
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
|
||||
ignoreUnlabeled == null ? false : ignoreUnlabeled,
|
||||
k == null ? 10 : k);
|
||||
k == null ? DEFAULT_K : k);
|
||||
});
|
||||
|
||||
static {
|
||||
|
@ -111,6 +111,10 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
k = in.readVInt();
|
||||
}
|
||||
|
||||
int getK() {
|
||||
return this.k;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(relevantRatingThreshhold);
|
||||
|
|
|
@ -42,13 +42,18 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
|
|||
|
||||
public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
|
||||
static final double EXPECTED_DCG = 13.84826362927298;
|
||||
static final double EXPECTED_IDCG = 14.595390756454922;
|
||||
static final double EXPECTED_NDCG = EXPECTED_DCG / EXPECTED_IDCG;
|
||||
private static final double DELTA = 10E-16;
|
||||
|
||||
/**
|
||||
* Assuming the docs are ranked in the following order:
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* -------------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 |
|
||||
* 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 |
|
||||
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5
|
||||
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
|
||||
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
||||
|
@ -66,7 +71,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null));
|
||||
}
|
||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||
assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
assertEquals(EXPECTED_DCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
|
||||
|
||||
/**
|
||||
* Check with normalization: to get the maximal possible dcg, sort documents by
|
||||
|
@ -83,8 +88,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
*
|
||||
* idcg = 14.595390756454922 (sum of last column)
|
||||
*/
|
||||
dcg = new DiscountedCumulativeGain(true, null);
|
||||
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
dcg = new DiscountedCumulativeGain(true, null, 10);
|
||||
assertEquals(EXPECTED_NDCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -117,7 +122,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
}
|
||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||
EvalQueryQuality result = dcg.evaluate("id", hits, rated);
|
||||
assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001);
|
||||
assertEquals(12.779642067948913, result.getQualityLevel(), DELTA);
|
||||
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
||||
|
||||
/**
|
||||
|
@ -135,8 +140,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
*
|
||||
* idcg = 13.347184833073591 (sum of last column)
|
||||
*/
|
||||
dcg = new DiscountedCumulativeGain(true, null);
|
||||
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
dcg = new DiscountedCumulativeGain(true, null, 10);
|
||||
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -174,7 +179,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
}
|
||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
|
||||
assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001);
|
||||
assertEquals(12.392789260714371, result.getQualityLevel(), DELTA);
|
||||
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
||||
|
||||
/**
|
||||
|
@ -193,16 +198,27 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
*
|
||||
* idcg = 13.347184833073591 (sum of last column)
|
||||
*/
|
||||
dcg = new DiscountedCumulativeGain(true, null);
|
||||
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
|
||||
dcg = new DiscountedCumulativeGain(true, null, 10);
|
||||
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), DELTA);
|
||||
}
|
||||
|
||||
public void testParseFromXContent() throws IOException {
|
||||
String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }";
|
||||
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true, \"k\" : 15 }", 2, true, 15);
|
||||
assertParsedCorrect("{ \"normalize\": false, \"k\" : 15 }", null, false, 15);
|
||||
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"k\" : 15 }", 2, false, 15);
|
||||
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true }", 2, true, 10);
|
||||
assertParsedCorrect("{ \"normalize\": true }", null, true, 10);
|
||||
assertParsedCorrect("{ \"k\": 23 }", null, false, 23);
|
||||
assertParsedCorrect("{ \"unknown_doc_rating\": 2 }", 2, false, 10);
|
||||
}
|
||||
|
||||
private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, boolean expectedNormalize, int expectedK)
|
||||
throws IOException {
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
|
||||
assertEquals(2, dcgAt.getUnknownDocRating().intValue());
|
||||
assertEquals(true, dcgAt.getNormalize());
|
||||
assertEquals(expectedUnknownDocRating, dcgAt.getUnknownDocRating());
|
||||
assertEquals(expectedNormalize, dcgAt.getNormalize());
|
||||
assertEquals(expectedK, dcgAt.getK());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -210,7 +226,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
boolean normalize = randomBoolean();
|
||||
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
|
||||
|
||||
return new DiscountedCumulativeGain(normalize, unknownDocRating);
|
||||
return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
|
||||
}
|
||||
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
|
@ -238,16 +254,22 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
checkEqualsAndHashCode(createTestItem(), original -> {
|
||||
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating());
|
||||
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), original.getK());
|
||||
}, DiscountedCumulativeGainTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
|
||||
if (randomBoolean()) {
|
||||
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating());
|
||||
} else {
|
||||
switch (randomIntBetween(0, 2)) {
|
||||
case 0:
|
||||
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating(), original.getK());
|
||||
case 1:
|
||||
return new DiscountedCumulativeGain(original.getNormalize(),
|
||||
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)));
|
||||
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK());
|
||||
case 2:
|
||||
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(),
|
||||
randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10)));
|
||||
default:
|
||||
throw new IllegalArgumentException("mutation variant not allowed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,12 +48,28 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
|||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||
assertEquals(1, mrr.getRelevantRatingThreshold());
|
||||
assertEquals(10, mrr.getK());
|
||||
}
|
||||
|
||||
xContent = "{ \"relevant_rating_threshold\": 2 }";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||
assertEquals(2, mrr.getRelevantRatingThreshold());
|
||||
assertEquals(10, mrr.getK());
|
||||
}
|
||||
|
||||
xContent = "{ \"relevant_rating_threshold\": 2, \"k\" : 15 }";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||
assertEquals(2, mrr.getRelevantRatingThreshold());
|
||||
assertEquals(15, mrr.getK());
|
||||
}
|
||||
|
||||
xContent = "{ \"k\" : 15 }";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||
assertEquals(1, mrr.getRelevantRatingThreshold());
|
||||
assertEquals(15, mrr.getK());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,7 +132,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
|||
rated.add(new RatedDocument("test", "4", 4));
|
||||
SearchHit[] hits = createSearchHits(0, 5, "test");
|
||||
|
||||
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2);
|
||||
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10);
|
||||
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
|
||||
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
|
||||
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
|
||||
|
@ -167,7 +183,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
|||
}
|
||||
|
||||
static MeanReciprocalRank createTestItem() {
|
||||
return new MeanReciprocalRank(randomIntBetween(0, 20));
|
||||
return new MeanReciprocalRank(randomIntBetween(0, 20), randomIntBetween(1, 20));
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
|
@ -184,14 +200,22 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
|
||||
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold());
|
||||
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK());
|
||||
}
|
||||
|
||||
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
|
||||
return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)));
|
||||
if (randomBoolean()) {
|
||||
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold() + 1, testItem.getK());
|
||||
} else {
|
||||
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1, 1));
|
||||
}
|
||||
|
||||
public void testInvalidK() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(1, -1));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,18 +20,17 @@
|
|||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
|
@ -64,13 +63,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
refresh();
|
||||
}
|
||||
|
||||
/**
|
||||
* Test cases retrieves all six documents indexed above. The first part checks the Prec@10 calculation where
|
||||
* all unlabeled docs are treated as "unrelevant". We average Prec@ metric across two search use cases, the
|
||||
* first one that labels 4 out of the 6 documents as relevant, the second one with only one relevant document.
|
||||
*/
|
||||
public void testPrecisionAtRequest() {
|
||||
List<String> indices = Arrays.asList(new String[] { "test" });
|
||||
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME);
|
||||
testQuery.sort("_id");
|
||||
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
|
||||
createRelevant("2", "3", "4", "5"), testQuery);
|
||||
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
||||
|
@ -79,12 +81,11 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"),
|
||||
testQuery);
|
||||
berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
||||
|
||||
specifications.add(berlinRequest);
|
||||
|
||||
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(indices);
|
||||
task.addIndices(Collections.singletonList("test"));
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
|
||||
RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
|
@ -92,6 +93,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
|
||||
.actionGet();
|
||||
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
|
||||
// second is 1/6, divided by 2 to get the average
|
||||
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
|
||||
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
|
||||
|
@ -129,14 +132,96 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
// test that a different window size k affects the result
|
||||
metric = new PrecisionAtK(1, false, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(indices);
|
||||
task.addIndices(Collections.singletonList("test"));
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
// if we look only at top 3 documente, the expected P@3 for the first query is
|
||||
// 2/3 and the expected Prec@ for the second is 1/3, divided by 2 to get the average
|
||||
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
|
||||
assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
/**
|
||||
* This test assumes we are using the same ratings as in {@link DiscountedCumulativeGainTests#testDCGAt()}.
|
||||
* See details in that test case for how the expected values are calculated
|
||||
*/
|
||||
public void testDCGRequest() {
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
testQuery.sort("_id");
|
||||
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(
|
||||
new RatedDocument("test", "1", 3),
|
||||
new RatedDocument("test", "2", 2),
|
||||
new RatedDocument("test", "3", 3),
|
||||
new RatedDocument("test", "4", 0),
|
||||
new RatedDocument("test", "5", 1),
|
||||
new RatedDocument("test", "6", 2));
|
||||
specifications.add(new RatedRequest("amsterdam_query", ratedDocs, testQuery));
|
||||
|
||||
DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10);
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(Collections.singletonList("test"));
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
|
||||
// test that a different window size k affects the result
|
||||
metric = new DiscountedCumulativeGain(false, null, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(Collections.singletonList("test"));
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(12.392789260714371, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testMRRRequest() {
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
testQuery.sort("_id");
|
||||
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
specifications.add(new RatedRequest("amsterdam_query", createRelevant("5"), testQuery));
|
||||
specifications.add(new RatedRequest("berlin_query", createRelevant("1"), testQuery));
|
||||
|
||||
MeanReciprocalRank metric = new MeanReciprocalRank(1, 10);
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(Collections.singletonList("test"));
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
// the expected reciprocal rank for the amsterdam_query is 1/5
|
||||
// the expected reciprocal rank for the berlin_query is 1/1
|
||||
// dividing by 2 to get the average
|
||||
double expectedMRR = (1.0 / 1.0 + 1.0 / 5.0) / 2.0;
|
||||
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
|
||||
|
||||
// test that a different window size k affects the result
|
||||
metric = new MeanReciprocalRank(1, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(Collections.singletonList("test"));
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
// limiting to top 3 results, the amsterdam_query has no relevant document in it
|
||||
// the reciprocal rank for the berlin_query is 1/1
|
||||
// dividing by 2 to get the average
|
||||
expectedMRR = (1.0/ 1.0) / 2.0;
|
||||
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -162,16 +247,13 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
|
||||
task.addIndices(indices);
|
||||
|
||||
try (Client client = client()) {
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(1, response.getFailures().size());
|
||||
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
|
||||
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
|
||||
rootCauses[0].getCause().toString());
|
||||
}
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(1, response.getFailures().size());
|
||||
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
|
||||
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", rootCauses[0].getCause().toString());
|
||||
}
|
||||
|
||||
private static List<RatedDocument> createRelevant(String... docs) {
|
||||
|
|
Loading…
Reference in New Issue