Add search window parameter k to MRR and DCG metric (#27595)

This commit is contained in:
Christoph Büscher 2017-12-04 10:54:03 +01:00 committed by GitHub
parent 35688f6441
commit 72d0de4197
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 248 additions and 69 deletions

View File

@ -33,21 +33,29 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
/**
* Metric implementing Discounted Cumulative Gain (https://en.wikipedia.org/wiki/Discounted_cumulative_gain).<br>
* Metric implementing Discounted Cumulative Gain.
* The `normalize` parameter can be set to calculate the normalized NDCG (set to <tt>false</tt> by default).<br>
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents.
* @see <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Discounted_Cumulative_Gain">Discounted Cumulative Gain</a><br>
*/
public class DiscountedCumulativeGain implements EvaluationMetric {
/** If set to true, the dcg will be normalized (ndcg) */
private final boolean normalize;
/** the default search window size */
private static final int DEFAULT_K = 10;
/** the search window size */
private final int k;
/**
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
*/
@ -57,7 +65,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
private static final double LOG2 = Math.log(2.0);
public DiscountedCumulativeGain() {
this(false, null);
this(false, null, DEFAULT_K);
}
/**
@ -65,23 +73,27 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
* If set to true, dcg will be normalized (ndcg) See
* https://en.wikipedia.org/wiki/Discounted_cumulative_gain
* @param unknownDocRating
* the rating for docs the user hasn't supplied an explicit
* the rating for documents the user hasn't supplied an explicit
* rating for
* @param k the search window size all request use.
*/
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating) {
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating, int k) {
this.normalize = normalize;
this.unknownDocRating = unknownDocRating;
this.k = k;
}
DiscountedCumulativeGain(StreamInput in) throws IOException {
normalize = in.readBoolean();
unknownDocRating = in.readOptionalVInt();
k = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(normalize);
out.writeOptionalVInt(unknownDocRating);
out.writeVInt(k);
}
@Override
@ -89,13 +101,14 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return NAME;
}
/**
* check whether this metric computes only dcg or "normalized" ndcg
*/
public boolean getNormalize() {
boolean getNormalize() {
return this.normalize;
}
int getK() {
return this.k;
}
/**
* get the rating used for unrated documents
*/
@ -103,6 +116,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return this.unknownDocRating;
}
@Override
public Optional<Integer> forcedSearchSize() {
return Optional.of(k);
}
@Override
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
List<RatedDocument> ratedDocs) {
@ -142,17 +161,21 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
return dcg;
}
private static final ParseField K_FIELD = new ParseField("k");
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
args -> {
Boolean normalized = (Boolean) args[0];
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
Integer optK = (Integer) args[2];
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1],
optK == null ? DEFAULT_K : optK);
});
static {
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
@ -167,6 +190,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
if (unknownDocRating != null) {
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
}
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
@ -182,11 +206,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
}
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
return Objects.equals(normalize, other.normalize)
&& Objects.equals(unknownDocRating, other.unknownDocRating);
&& Objects.equals(unknownDocRating, other.unknownDocRating)
&& Objects.equals(k, other.k);
}
@Override
public final int hashCode() {
return Objects.hash(normalize, unknownDocRating);
return Objects.hash(normalize, unknownDocRating, k);
}
}

View File

@ -42,37 +42,57 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
*/
public class MeanReciprocalRank implements EvaluationMetric {
private static final int DEFAULT_RATING_THRESHOLD = 1;
public static final String NAME = "mean_reciprocal_rank";
/** ratings equal or above this value will be considered relevant. */
private static final int DEFAULT_RATING_THRESHOLD = 1;
private static final int DEFAULT_K = 10;
/** the search window size */
private final int k;
/** ratings equal or above this value will be considered relevant */
private final int relevantRatingThreshhold;
public MeanReciprocalRank() {
this(DEFAULT_RATING_THRESHOLD);
this(DEFAULT_RATING_THRESHOLD, DEFAULT_K);
}
MeanReciprocalRank(StreamInput in) throws IOException {
this.relevantRatingThreshhold = in.readVInt();
this.k = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRatingThreshhold);
out.writeVInt(this.relevantRatingThreshhold);
out.writeVInt(this.k);
}
/**
* Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).<br>
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevalnt". Defaults to 1.
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevant". Defaults to 1.
* @param k the search window size all request use.
*/
public MeanReciprocalRank(int relevantRatingThreshold) {
public MeanReciprocalRank(int relevantRatingThreshold, int k) {
if (relevantRatingThreshold < 0) {
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
}
if (k <= 0) {
throw new IllegalArgumentException("Window size k must be positive.");
}
this.k = k;
this.relevantRatingThreshhold = relevantRatingThreshold;
}
int getK() {
return this.k;
}
@Override
public Optional<Integer> forcedSearchSize() {
return Optional.of(k);
}
@Override
public String getWriteableName() {
return NAME;
@ -113,18 +133,18 @@ public class MeanReciprocalRank implements EvaluationMetric {
}
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField K_FIELD = new ParseField("k");
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
args -> {
Integer optionalThreshold = (Integer) args[0];
if (optionalThreshold == null) {
return new MeanReciprocalRank();
} else {
return new MeanReciprocalRank(optionalThreshold);
}
Integer optionalK = (Integer) args[1];
return new MeanReciprocalRank(optionalThreshold == null ? DEFAULT_RATING_THRESHOLD : optionalThreshold,
optionalK == null ? DEFAULT_K : optionalK);
});
static {
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
public static MeanReciprocalRank fromXContent(XContentParser parser) {
@ -136,6 +156,7 @@ public class MeanReciprocalRank implements EvaluationMetric {
builder.startObject();
builder.startObject(NAME);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
@ -150,12 +171,13 @@ public class MeanReciprocalRank implements EvaluationMetric {
return false;
}
MeanReciprocalRank other = (MeanReciprocalRank) obj;
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
&& Objects.equals(k, other.k);
}
@Override
public final int hashCode() {
return Objects.hash(relevantRatingThreshhold);
return Objects.hash(relevantRatingThreshhold, k);
}
static class Breakdown implements MetricDetails {

View File

@ -96,7 +96,7 @@ public class PrecisionAtK implements EvaluationMetric {
Integer k = (Integer) args[2];
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
ignoreUnlabeled == null ? false : ignoreUnlabeled,
k == null ? 10 : k);
k == null ? DEFAULT_K : k);
});
static {
@ -111,6 +111,10 @@ public class PrecisionAtK implements EvaluationMetric {
k = in.readVInt();
}
int getK() {
return this.k;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRatingThreshhold);

View File

@ -42,13 +42,18 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
public class DiscountedCumulativeGainTests extends ESTestCase {
static final double EXPECTED_DCG = 13.84826362927298;
static final double EXPECTED_IDCG = 14.595390756454922;
static final double EXPECTED_NDCG = EXPECTED_DCG / EXPECTED_IDCG;
private static final double DELTA = 10E-16;
/**
* Assuming the docs are ranked in the following order:
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 
* 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 | 
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
@ -66,7 +71,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null));
}
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
assertEquals(EXPECTED_DCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
/**
* Check with normalization: to get the maximal possible dcg, sort documents by
@ -83,8 +88,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
*
* idcg = 14.595390756454922 (sum of last column)
*/
dcg = new DiscountedCumulativeGain(true, null);
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
dcg = new DiscountedCumulativeGain(true, null, 10);
assertEquals(EXPECTED_NDCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
}
/**
@ -117,7 +122,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
}
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
EvalQueryQuality result = dcg.evaluate("id", hits, rated);
assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001);
assertEquals(12.779642067948913, result.getQualityLevel(), DELTA);
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
/**
@ -135,8 +140,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
*
* idcg = 13.347184833073591 (sum of last column)
*/
dcg = new DiscountedCumulativeGain(true, null);
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
dcg = new DiscountedCumulativeGain(true, null, 10);
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
}
/**
@ -174,7 +179,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
}
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001);
assertEquals(12.392789260714371, result.getQualityLevel(), DELTA);
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
/**
@ -193,16 +198,27 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
*
* idcg = 13.347184833073591 (sum of last column)
*/
dcg = new DiscountedCumulativeGain(true, null);
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
dcg = new DiscountedCumulativeGain(true, null, 10);
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), DELTA);
}
public void testParseFromXContent() throws IOException {
String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }";
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true, \"k\" : 15 }", 2, true, 15);
assertParsedCorrect("{ \"normalize\": false, \"k\" : 15 }", null, false, 15);
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"k\" : 15 }", 2, false, 15);
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true }", 2, true, 10);
assertParsedCorrect("{ \"normalize\": true }", null, true, 10);
assertParsedCorrect("{ \"k\": 23 }", null, false, 23);
assertParsedCorrect("{ \"unknown_doc_rating\": 2 }", 2, false, 10);
}
private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, boolean expectedNormalize, int expectedK)
throws IOException {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
assertEquals(2, dcgAt.getUnknownDocRating().intValue());
assertEquals(true, dcgAt.getNormalize());
assertEquals(expectedUnknownDocRating, dcgAt.getUnknownDocRating());
assertEquals(expectedNormalize, dcgAt.getNormalize());
assertEquals(expectedK, dcgAt.getK());
}
}
@ -210,7 +226,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
boolean normalize = randomBoolean();
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
return new DiscountedCumulativeGain(normalize, unknownDocRating);
return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
}
public void testXContentRoundtrip() throws IOException {
@ -238,16 +254,22 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
public void testEqualsAndHash() throws IOException {
checkEqualsAndHashCode(createTestItem(), original -> {
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating());
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), original.getK());
}, DiscountedCumulativeGainTests::mutateTestItem);
}
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
if (randomBoolean()) {
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating());
} else {
switch (randomIntBetween(0, 2)) {
case 0:
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating(), original.getK());
case 1:
return new DiscountedCumulativeGain(original.getNormalize(),
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)));
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK());
case 2:
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(),
randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10)));
default:
throw new IllegalArgumentException("mutation variant not allowed");
}
}
}

View File

@ -48,12 +48,28 @@ public class MeanReciprocalRankTests extends ESTestCase {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(1, mrr.getRelevantRatingThreshold());
assertEquals(10, mrr.getK());
}
xContent = "{ \"relevant_rating_threshold\": 2 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(2, mrr.getRelevantRatingThreshold());
assertEquals(10, mrr.getK());
}
xContent = "{ \"relevant_rating_threshold\": 2, \"k\" : 15 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(2, mrr.getRelevantRatingThreshold());
assertEquals(15, mrr.getK());
}
xContent = "{ \"k\" : 15 }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
assertEquals(1, mrr.getRelevantRatingThreshold());
assertEquals(15, mrr.getK());
}
}
@ -116,7 +132,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
rated.add(new RatedDocument("test", "4", 4));
SearchHit[] hits = createSearchHits(0, 5, "test");
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2);
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10);
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
@ -167,7 +183,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
}
static MeanReciprocalRank createTestItem() {
return new MeanReciprocalRank(randomIntBetween(0, 20));
return new MeanReciprocalRank(randomIntBetween(0, 20), randomIntBetween(1, 20));
}
public void testSerialization() throws IOException {
@ -184,14 +200,22 @@ public class MeanReciprocalRankTests extends ESTestCase {
}
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold());
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK());
}
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)));
if (randomBoolean()) {
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold() + 1, testItem.getK());
} else {
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK() + 1);
}
}
public void testInvalidRelevantThreshold() {
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1));
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1, 1));
}
public void testInvalidK() {
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(1, -1));
}
}

View File

@ -20,18 +20,17 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@ -64,13 +63,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
refresh();
}
/**
* Test cases retrieves all six documents indexed above. The first part checks the Prec@10 calculation where
* all unlabeled docs are treated as "unrelevant". We average Prec@ metric across two search use cases, the
* first one that labels 4 out of the 6 documents as relevant, the second one with only one relevant document.
*/
public void testPrecisionAtRequest() {
List<String> indices = Arrays.asList(new String[] { "test" });
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME);
testQuery.sort("_id");
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
createRelevant("2", "3", "4", "5"), testQuery);
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
@ -79,12 +81,11 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"),
testQuery);
berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
specifications.add(berlinRequest);
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(indices);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
RankEvalAction.INSTANCE, new RankEvalRequest());
@ -92,6 +93,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
.actionGet();
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
// second is 1/6, divided by 2 to get the average
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
@ -129,14 +132,96 @@ public class RankEvalRequestIT extends ESIntegTestCase {
// test that a different window size k affects the result
metric = new PrecisionAtK(1, false, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(indices);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// if we look only at top 3 documente, the expected P@3 for the first query is
// 2/3 and the expected Prec@ for the second is 1/3, divided by 2 to get the average
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE);
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
}
/**
* This test assumes we are using the same ratings as in {@link DiscountedCumulativeGainTests#testDCGAt()}.
* See details in that test case for how the expected values are calculated
*/
public void testDCGRequest() {
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
testQuery.sort("_id");
List<RatedRequest> specifications = new ArrayList<>();
List<RatedDocument> ratedDocs = Arrays.asList(
new RatedDocument("test", "1", 3),
new RatedDocument("test", "2", 2),
new RatedDocument("test", "3", 3),
new RatedDocument("test", "4", 0),
new RatedDocument("test", "5", 1),
new RatedDocument("test", "6", 2));
specifications.add(new RatedRequest("amsterdam_query", ratedDocs, testQuery));
DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), Double.MIN_VALUE);
// test that a different window size k affects the result
metric = new DiscountedCumulativeGain(false, null, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(12.392789260714371, response.getEvaluationResult(), Double.MIN_VALUE);
}
public void testMRRRequest() {
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
testQuery.sort("_id");
List<RatedRequest> specifications = new ArrayList<>();
specifications.add(new RatedRequest("amsterdam_query", createRelevant("5"), testQuery));
specifications.add(new RatedRequest("berlin_query", createRelevant("1"), testQuery));
MeanReciprocalRank metric = new MeanReciprocalRank(1, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// the expected reciprocal rank for the amsterdam_query is 1/5
// the expected reciprocal rank for the berlin_query is 1/1
// dividing by 2 to get the average
double expectedMRR = (1.0 / 1.0 + 1.0 / 5.0) / 2.0;
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
// test that a different window size k affects the result
metric = new MeanReciprocalRank(1, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// limiting to top 3 results, the amsterdam_query has no relevant document in it
// the reciprocal rank for the berlin_query is 1/1
// dividing by 2 to get the average
expectedMRR = (1.0/ 1.0) / 2.0;
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
}
/**
@ -162,16 +247,13 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
task.addIndices(indices);
try (Client client = client()) {
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(1, response.getFailures().size());
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
rootCauses[0].getCause().toString());
}
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(1, response.getFailures().size());
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", rootCauses[0].getCause().toString());
}
private static List<RatedDocument> createRelevant(String... docs) {