Add search window parameter k to MRR and DCG metric (#27595)
This commit is contained in:
parent
35688f6441
commit
72d0de4197
|
@ -33,21 +33,29 @@ import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Metric implementing Discounted Cumulative Gain (https://en.wikipedia.org/wiki/Discounted_cumulative_gain).<br>
|
* Metric implementing Discounted Cumulative Gain.
|
||||||
* The `normalize` parameter can be set to calculate the normalized NDCG (set to <tt>false</tt> by default).<br>
|
* The `normalize` parameter can be set to calculate the normalized NDCG (set to <tt>false</tt> by default).<br>
|
||||||
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents.
|
* The optional `unknown_doc_rating` parameter can be used to specify a default rating for unlabeled documents.
|
||||||
|
* @see <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Discounted_Cumulative_Gain">Discounted Cumulative Gain</a><br>
|
||||||
*/
|
*/
|
||||||
public class DiscountedCumulativeGain implements EvaluationMetric {
|
public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
|
|
||||||
/** If set to true, the dcg will be normalized (ndcg) */
|
/** If set to true, the dcg will be normalized (ndcg) */
|
||||||
private final boolean normalize;
|
private final boolean normalize;
|
||||||
|
|
||||||
|
/** the default search window size */
|
||||||
|
private static final int DEFAULT_K = 10;
|
||||||
|
|
||||||
|
/** the search window size */
|
||||||
|
private final int k;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
|
* Optional. If set, this will be the rating for docs that are unrated in the ranking evaluation request
|
||||||
*/
|
*/
|
||||||
|
@ -57,7 +65,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
private static final double LOG2 = Math.log(2.0);
|
private static final double LOG2 = Math.log(2.0);
|
||||||
|
|
||||||
public DiscountedCumulativeGain() {
|
public DiscountedCumulativeGain() {
|
||||||
this(false, null);
|
this(false, null, DEFAULT_K);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -65,23 +73,27 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
* If set to true, dcg will be normalized (ndcg) See
|
* If set to true, dcg will be normalized (ndcg) See
|
||||||
* https://en.wikipedia.org/wiki/Discounted_cumulative_gain
|
* https://en.wikipedia.org/wiki/Discounted_cumulative_gain
|
||||||
* @param unknownDocRating
|
* @param unknownDocRating
|
||||||
* the rating for docs the user hasn't supplied an explicit
|
* the rating for documents the user hasn't supplied an explicit
|
||||||
* rating for
|
* rating for
|
||||||
|
* @param k the search window size all request use.
|
||||||
*/
|
*/
|
||||||
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating) {
|
public DiscountedCumulativeGain(boolean normalize, Integer unknownDocRating, int k) {
|
||||||
this.normalize = normalize;
|
this.normalize = normalize;
|
||||||
this.unknownDocRating = unknownDocRating;
|
this.unknownDocRating = unknownDocRating;
|
||||||
|
this.k = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscountedCumulativeGain(StreamInput in) throws IOException {
|
DiscountedCumulativeGain(StreamInput in) throws IOException {
|
||||||
normalize = in.readBoolean();
|
normalize = in.readBoolean();
|
||||||
unknownDocRating = in.readOptionalVInt();
|
unknownDocRating = in.readOptionalVInt();
|
||||||
|
k = in.readVInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeBoolean(normalize);
|
out.writeBoolean(normalize);
|
||||||
out.writeOptionalVInt(unknownDocRating);
|
out.writeOptionalVInt(unknownDocRating);
|
||||||
|
out.writeVInt(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -89,13 +101,14 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
return NAME;
|
return NAME;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
boolean getNormalize() {
|
||||||
* check whether this metric computes only dcg or "normalized" ndcg
|
|
||||||
*/
|
|
||||||
public boolean getNormalize() {
|
|
||||||
return this.normalize;
|
return this.normalize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int getK() {
|
||||||
|
return this.k;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get the rating used for unrated documents
|
* get the rating used for unrated documents
|
||||||
*/
|
*/
|
||||||
|
@ -103,6 +116,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
return this.unknownDocRating;
|
return this.unknownDocRating;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<Integer> forcedSearchSize() {
|
||||||
|
return Optional.of(k);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
|
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
|
||||||
List<RatedDocument> ratedDocs) {
|
List<RatedDocument> ratedDocs) {
|
||||||
|
@ -142,17 +161,21 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
return dcg;
|
return dcg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final ParseField K_FIELD = new ParseField("k");
|
||||||
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
|
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
|
||||||
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
|
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
|
||||||
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
|
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
|
||||||
args -> {
|
args -> {
|
||||||
Boolean normalized = (Boolean) args[0];
|
Boolean normalized = (Boolean) args[0];
|
||||||
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
|
Integer optK = (Integer) args[2];
|
||||||
|
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1],
|
||||||
|
optK == null ? DEFAULT_K : optK);
|
||||||
});
|
});
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
|
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
|
||||||
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
|
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
|
||||||
|
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
|
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
|
||||||
|
@ -167,6 +190,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
if (unknownDocRating != null) {
|
if (unknownDocRating != null) {
|
||||||
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
|
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
|
||||||
}
|
}
|
||||||
|
builder.field(K_FIELD.getPreferredName(), this.k);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -182,11 +206,12 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
|
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
|
||||||
return Objects.equals(normalize, other.normalize)
|
return Objects.equals(normalize, other.normalize)
|
||||||
&& Objects.equals(unknownDocRating, other.unknownDocRating);
|
&& Objects.equals(unknownDocRating, other.unknownDocRating)
|
||||||
|
&& Objects.equals(k, other.k);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final int hashCode() {
|
public final int hashCode() {
|
||||||
return Objects.hash(normalize, unknownDocRating);
|
return Objects.hash(normalize, unknownDocRating, k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,37 +42,57 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
|
||||||
*/
|
*/
|
||||||
public class MeanReciprocalRank implements EvaluationMetric {
|
public class MeanReciprocalRank implements EvaluationMetric {
|
||||||
|
|
||||||
private static final int DEFAULT_RATING_THRESHOLD = 1;
|
|
||||||
|
|
||||||
public static final String NAME = "mean_reciprocal_rank";
|
public static final String NAME = "mean_reciprocal_rank";
|
||||||
|
|
||||||
/** ratings equal or above this value will be considered relevant. */
|
private static final int DEFAULT_RATING_THRESHOLD = 1;
|
||||||
|
private static final int DEFAULT_K = 10;
|
||||||
|
|
||||||
|
/** the search window size */
|
||||||
|
private final int k;
|
||||||
|
|
||||||
|
/** ratings equal or above this value will be considered relevant */
|
||||||
private final int relevantRatingThreshhold;
|
private final int relevantRatingThreshhold;
|
||||||
|
|
||||||
public MeanReciprocalRank() {
|
public MeanReciprocalRank() {
|
||||||
this(DEFAULT_RATING_THRESHOLD);
|
this(DEFAULT_RATING_THRESHOLD, DEFAULT_K);
|
||||||
}
|
}
|
||||||
|
|
||||||
MeanReciprocalRank(StreamInput in) throws IOException {
|
MeanReciprocalRank(StreamInput in) throws IOException {
|
||||||
this.relevantRatingThreshhold = in.readVInt();
|
this.relevantRatingThreshhold = in.readVInt();
|
||||||
|
this.k = in.readVInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeVInt(relevantRatingThreshhold);
|
out.writeVInt(this.relevantRatingThreshhold);
|
||||||
|
out.writeVInt(this.k);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).<br>
|
* Metric implementing Mean Reciprocal Rank (https://en.wikipedia.org/wiki/Mean_reciprocal_rank).<br>
|
||||||
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevalnt". Defaults to 1.
|
* @param relevantRatingThreshold the rating value that a document needs to be regarded as "relevant". Defaults to 1.
|
||||||
|
* @param k the search window size all request use.
|
||||||
*/
|
*/
|
||||||
public MeanReciprocalRank(int relevantRatingThreshold) {
|
public MeanReciprocalRank(int relevantRatingThreshold, int k) {
|
||||||
if (relevantRatingThreshold < 0) {
|
if (relevantRatingThreshold < 0) {
|
||||||
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
||||||
}
|
}
|
||||||
|
if (k <= 0) {
|
||||||
|
throw new IllegalArgumentException("Window size k must be positive.");
|
||||||
|
}
|
||||||
|
this.k = k;
|
||||||
this.relevantRatingThreshhold = relevantRatingThreshold;
|
this.relevantRatingThreshhold = relevantRatingThreshold;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int getK() {
|
||||||
|
return this.k;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<Integer> forcedSearchSize() {
|
||||||
|
return Optional.of(k);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getWriteableName() {
|
public String getWriteableName() {
|
||||||
return NAME;
|
return NAME;
|
||||||
|
@ -113,18 +133,18 @@ public class MeanReciprocalRank implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||||
|
private static final ParseField K_FIELD = new ParseField("k");
|
||||||
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
|
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
|
||||||
args -> {
|
args -> {
|
||||||
Integer optionalThreshold = (Integer) args[0];
|
Integer optionalThreshold = (Integer) args[0];
|
||||||
if (optionalThreshold == null) {
|
Integer optionalK = (Integer) args[1];
|
||||||
return new MeanReciprocalRank();
|
return new MeanReciprocalRank(optionalThreshold == null ? DEFAULT_RATING_THRESHOLD : optionalThreshold,
|
||||||
} else {
|
optionalK == null ? DEFAULT_K : optionalK);
|
||||||
return new MeanReciprocalRank(optionalThreshold);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
||||||
|
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static MeanReciprocalRank fromXContent(XContentParser parser) {
|
public static MeanReciprocalRank fromXContent(XContentParser parser) {
|
||||||
|
@ -136,6 +156,7 @@ public class MeanReciprocalRank implements EvaluationMetric {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.startObject(NAME);
|
builder.startObject(NAME);
|
||||||
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
|
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
|
||||||
|
builder.field(K_FIELD.getPreferredName(), this.k);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -150,12 +171,13 @@ public class MeanReciprocalRank implements EvaluationMetric {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MeanReciprocalRank other = (MeanReciprocalRank) obj;
|
MeanReciprocalRank other = (MeanReciprocalRank) obj;
|
||||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
|
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
|
||||||
|
&& Objects.equals(k, other.k);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final int hashCode() {
|
public final int hashCode() {
|
||||||
return Objects.hash(relevantRatingThreshhold);
|
return Objects.hash(relevantRatingThreshhold, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
static class Breakdown implements MetricDetails {
|
static class Breakdown implements MetricDetails {
|
||||||
|
|
|
@ -96,7 +96,7 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||||
Integer k = (Integer) args[2];
|
Integer k = (Integer) args[2];
|
||||||
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
|
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
|
||||||
ignoreUnlabeled == null ? false : ignoreUnlabeled,
|
ignoreUnlabeled == null ? false : ignoreUnlabeled,
|
||||||
k == null ? 10 : k);
|
k == null ? DEFAULT_K : k);
|
||||||
});
|
});
|
||||||
|
|
||||||
static {
|
static {
|
||||||
|
@ -111,6 +111,10 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||||
k = in.readVInt();
|
k = in.readVInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int getK() {
|
||||||
|
return this.k;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeVInt(relevantRatingThreshhold);
|
out.writeVInt(relevantRatingThreshhold);
|
||||||
|
|
|
@ -42,13 +42,18 @@ import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashC
|
||||||
|
|
||||||
public class DiscountedCumulativeGainTests extends ESTestCase {
|
public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
|
|
||||||
|
static final double EXPECTED_DCG = 13.84826362927298;
|
||||||
|
static final double EXPECTED_IDCG = 14.595390756454922;
|
||||||
|
static final double EXPECTED_NDCG = EXPECTED_DCG / EXPECTED_IDCG;
|
||||||
|
private static final double DELTA = 10E-16;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Assuming the docs are ranked in the following order:
|
* Assuming the docs are ranked in the following order:
|
||||||
*
|
*
|
||||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||||
* -------------------------------------------------------------------------------------------
|
* -------------------------------------------------------------------------------------------
|
||||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 |
|
* 1 | 3 | 7.0 | 1.0 | 7.0 | 7.0 |
|
||||||
* 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||||
* 3 | 3 | 7.0 | 2.0 | 3.5
|
* 3 | 3 | 7.0 | 2.0 | 3.5
|
||||||
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
|
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
|
||||||
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
||||||
|
@ -66,7 +71,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null));
|
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0, null));
|
||||||
}
|
}
|
||||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||||
assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
assertEquals(EXPECTED_DCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check with normalization: to get the maximal possible dcg, sort documents by
|
* Check with normalization: to get the maximal possible dcg, sort documents by
|
||||||
|
@ -83,8 +88,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
*
|
*
|
||||||
* idcg = 14.595390756454922 (sum of last column)
|
* idcg = 14.595390756454922 (sum of last column)
|
||||||
*/
|
*/
|
||||||
dcg = new DiscountedCumulativeGain(true, null);
|
dcg = new DiscountedCumulativeGain(true, null, 10);
|
||||||
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
assertEquals(EXPECTED_NDCG, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -117,7 +122,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||||
EvalQueryQuality result = dcg.evaluate("id", hits, rated);
|
EvalQueryQuality result = dcg.evaluate("id", hits, rated);
|
||||||
assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001);
|
assertEquals(12.779642067948913, result.getQualityLevel(), DELTA);
|
||||||
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -135,8 +140,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
*
|
*
|
||||||
* idcg = 13.347184833073591 (sum of last column)
|
* idcg = 13.347184833073591 (sum of last column)
|
||||||
*/
|
*/
|
||||||
dcg = new DiscountedCumulativeGain(true, null);
|
dcg = new DiscountedCumulativeGain(true, null, 10);
|
||||||
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), DELTA);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -174,7 +179,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||||
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
|
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
|
||||||
assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001);
|
assertEquals(12.392789260714371, result.getQualityLevel(), DELTA);
|
||||||
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -193,16 +198,27 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
*
|
*
|
||||||
* idcg = 13.347184833073591 (sum of last column)
|
* idcg = 13.347184833073591 (sum of last column)
|
||||||
*/
|
*/
|
||||||
dcg = new DiscountedCumulativeGain(true, null);
|
dcg = new DiscountedCumulativeGain(true, null, 10);
|
||||||
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
|
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), DELTA);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testParseFromXContent() throws IOException {
|
public void testParseFromXContent() throws IOException {
|
||||||
String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }";
|
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true, \"k\" : 15 }", 2, true, 15);
|
||||||
|
assertParsedCorrect("{ \"normalize\": false, \"k\" : 15 }", null, false, 15);
|
||||||
|
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"k\" : 15 }", 2, false, 15);
|
||||||
|
assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"normalize\": true }", 2, true, 10);
|
||||||
|
assertParsedCorrect("{ \"normalize\": true }", null, true, 10);
|
||||||
|
assertParsedCorrect("{ \"k\": 23 }", null, false, 23);
|
||||||
|
assertParsedCorrect("{ \"unknown_doc_rating\": 2 }", 2, false, 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, boolean expectedNormalize, int expectedK)
|
||||||
|
throws IOException {
|
||||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||||
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
|
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
|
||||||
assertEquals(2, dcgAt.getUnknownDocRating().intValue());
|
assertEquals(expectedUnknownDocRating, dcgAt.getUnknownDocRating());
|
||||||
assertEquals(true, dcgAt.getNormalize());
|
assertEquals(expectedNormalize, dcgAt.getNormalize());
|
||||||
|
assertEquals(expectedK, dcgAt.getK());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +226,7 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
boolean normalize = randomBoolean();
|
boolean normalize = randomBoolean();
|
||||||
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
|
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
|
||||||
|
|
||||||
return new DiscountedCumulativeGain(normalize, unknownDocRating);
|
return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContentRoundtrip() throws IOException {
|
public void testXContentRoundtrip() throws IOException {
|
||||||
|
@ -238,16 +254,22 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||||
|
|
||||||
public void testEqualsAndHash() throws IOException {
|
public void testEqualsAndHash() throws IOException {
|
||||||
checkEqualsAndHashCode(createTestItem(), original -> {
|
checkEqualsAndHashCode(createTestItem(), original -> {
|
||||||
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating());
|
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(), original.getK());
|
||||||
}, DiscountedCumulativeGainTests::mutateTestItem);
|
}, DiscountedCumulativeGainTests::mutateTestItem);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
|
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
|
||||||
if (randomBoolean()) {
|
switch (randomIntBetween(0, 2)) {
|
||||||
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating());
|
case 0:
|
||||||
} else {
|
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating(), original.getK());
|
||||||
|
case 1:
|
||||||
return new DiscountedCumulativeGain(original.getNormalize(),
|
return new DiscountedCumulativeGain(original.getNormalize(),
|
||||||
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)));
|
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK());
|
||||||
|
case 2:
|
||||||
|
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating(),
|
||||||
|
randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10)));
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException("mutation variant not allowed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,12 +48,28 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
||||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||||
assertEquals(1, mrr.getRelevantRatingThreshold());
|
assertEquals(1, mrr.getRelevantRatingThreshold());
|
||||||
|
assertEquals(10, mrr.getK());
|
||||||
}
|
}
|
||||||
|
|
||||||
xContent = "{ \"relevant_rating_threshold\": 2 }";
|
xContent = "{ \"relevant_rating_threshold\": 2 }";
|
||||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||||
assertEquals(2, mrr.getRelevantRatingThreshold());
|
assertEquals(2, mrr.getRelevantRatingThreshold());
|
||||||
|
assertEquals(10, mrr.getK());
|
||||||
|
}
|
||||||
|
|
||||||
|
xContent = "{ \"relevant_rating_threshold\": 2, \"k\" : 15 }";
|
||||||
|
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||||
|
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||||
|
assertEquals(2, mrr.getRelevantRatingThreshold());
|
||||||
|
assertEquals(15, mrr.getK());
|
||||||
|
}
|
||||||
|
|
||||||
|
xContent = "{ \"k\" : 15 }";
|
||||||
|
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||||
|
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||||
|
assertEquals(1, mrr.getRelevantRatingThreshold());
|
||||||
|
assertEquals(15, mrr.getK());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +132,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
||||||
rated.add(new RatedDocument("test", "4", 4));
|
rated.add(new RatedDocument("test", "4", 4));
|
||||||
SearchHit[] hits = createSearchHits(0, 5, "test");
|
SearchHit[] hits = createSearchHits(0, 5, "test");
|
||||||
|
|
||||||
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2);
|
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10);
|
||||||
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
|
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
|
||||||
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
|
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
|
||||||
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
|
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
|
||||||
|
@ -167,7 +183,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
static MeanReciprocalRank createTestItem() {
|
static MeanReciprocalRank createTestItem() {
|
||||||
return new MeanReciprocalRank(randomIntBetween(0, 20));
|
return new MeanReciprocalRank(randomIntBetween(0, 20), randomIntBetween(1, 20));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSerialization() throws IOException {
|
public void testSerialization() throws IOException {
|
||||||
|
@ -184,14 +200,22 @@ public class MeanReciprocalRankTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
|
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
|
||||||
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold());
|
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
|
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
|
||||||
return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)));
|
if (randomBoolean()) {
|
||||||
|
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold() + 1, testItem.getK());
|
||||||
|
} else {
|
||||||
|
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold(), testItem.getK() + 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInvalidRelevantThreshold() {
|
public void testInvalidRelevantThreshold() {
|
||||||
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1));
|
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInvalidK() {
|
||||||
|
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(1, -1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,18 +20,17 @@
|
||||||
package org.elasticsearch.index.rankeval;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.ElasticsearchException;
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.client.Client;
|
|
||||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||||
import org.elasticsearch.index.query.QueryBuilders;
|
import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.plugins.Plugin;
|
import org.elasticsearch.plugins.Plugin;
|
||||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
|
||||||
import org.elasticsearch.test.ESIntegTestCase;
|
import org.elasticsearch.test.ESIntegTestCase;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
@ -64,13 +63,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||||
refresh();
|
refresh();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test cases retrieves all six documents indexed above. The first part checks the Prec@10 calculation where
|
||||||
|
* all unlabeled docs are treated as "unrelevant". We average Prec@ metric across two search use cases, the
|
||||||
|
* first one that labels 4 out of the 6 documents as relevant, the second one with only one relevant document.
|
||||||
|
*/
|
||||||
public void testPrecisionAtRequest() {
|
public void testPrecisionAtRequest() {
|
||||||
List<String> indices = Arrays.asList(new String[] { "test" });
|
|
||||||
|
|
||||||
List<RatedRequest> specifications = new ArrayList<>();
|
List<RatedRequest> specifications = new ArrayList<>();
|
||||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||||
testQuery.query(new MatchAllQueryBuilder());
|
testQuery.query(new MatchAllQueryBuilder());
|
||||||
testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME);
|
testQuery.sort("_id");
|
||||||
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
|
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
|
||||||
createRelevant("2", "3", "4", "5"), testQuery);
|
createRelevant("2", "3", "4", "5"), testQuery);
|
||||||
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
||||||
|
@ -79,12 +81,11 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||||
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"),
|
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("1"),
|
||||||
testQuery);
|
testQuery);
|
||||||
berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
berlinRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
||||||
|
|
||||||
specifications.add(berlinRequest);
|
specifications.add(berlinRequest);
|
||||||
|
|
||||||
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
|
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
|
||||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||||
task.addIndices(indices);
|
task.addIndices(Collections.singletonList("test"));
|
||||||
|
|
||||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
|
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
|
||||||
RankEvalAction.INSTANCE, new RankEvalRequest());
|
RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
|
@ -92,6 +93,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||||
|
|
||||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
|
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
|
||||||
.actionGet();
|
.actionGet();
|
||||||
|
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
|
||||||
|
// second is 1/6, divided by 2 to get the average
|
||||||
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
|
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
|
||||||
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
|
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||||
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
|
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
|
||||||
|
@ -129,14 +132,96 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||||
// test that a different window size k affects the result
|
// test that a different window size k affects the result
|
||||||
metric = new PrecisionAtK(1, false, 3);
|
metric = new PrecisionAtK(1, false, 3);
|
||||||
task = new RankEvalSpec(specifications, metric);
|
task = new RankEvalSpec(specifications, metric);
|
||||||
task.addIndices(indices);
|
task.addIndices(Collections.singletonList("test"));
|
||||||
|
|
||||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
builder.setRankEvalSpec(task);
|
builder.setRankEvalSpec(task);
|
||||||
|
|
||||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||||
|
// if we look only at top 3 documente, the expected P@3 for the first query is
|
||||||
|
// 2/3 and the expected Prec@ for the second is 1/3, divided by 2 to get the average
|
||||||
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
|
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
|
||||||
assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE);
|
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This test assumes we are using the same ratings as in {@link DiscountedCumulativeGainTests#testDCGAt()}.
|
||||||
|
* See details in that test case for how the expected values are calculated
|
||||||
|
*/
|
||||||
|
public void testDCGRequest() {
|
||||||
|
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||||
|
testQuery.query(new MatchAllQueryBuilder());
|
||||||
|
testQuery.sort("_id");
|
||||||
|
|
||||||
|
List<RatedRequest> specifications = new ArrayList<>();
|
||||||
|
List<RatedDocument> ratedDocs = Arrays.asList(
|
||||||
|
new RatedDocument("test", "1", 3),
|
||||||
|
new RatedDocument("test", "2", 2),
|
||||||
|
new RatedDocument("test", "3", 3),
|
||||||
|
new RatedDocument("test", "4", 0),
|
||||||
|
new RatedDocument("test", "5", 1),
|
||||||
|
new RatedDocument("test", "6", 2));
|
||||||
|
specifications.add(new RatedRequest("amsterdam_query", ratedDocs, testQuery));
|
||||||
|
|
||||||
|
DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10);
|
||||||
|
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||||
|
task.addIndices(Collections.singletonList("test"));
|
||||||
|
|
||||||
|
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
|
builder.setRankEvalSpec(task);
|
||||||
|
|
||||||
|
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||||
|
assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||||
|
|
||||||
|
// test that a different window size k affects the result
|
||||||
|
metric = new DiscountedCumulativeGain(false, null, 3);
|
||||||
|
task = new RankEvalSpec(specifications, metric);
|
||||||
|
task.addIndices(Collections.singletonList("test"));
|
||||||
|
|
||||||
|
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
|
builder.setRankEvalSpec(task);
|
||||||
|
|
||||||
|
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||||
|
assertEquals(12.392789260714371, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMRRRequest() {
|
||||||
|
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||||
|
testQuery.query(new MatchAllQueryBuilder());
|
||||||
|
testQuery.sort("_id");
|
||||||
|
|
||||||
|
List<RatedRequest> specifications = new ArrayList<>();
|
||||||
|
specifications.add(new RatedRequest("amsterdam_query", createRelevant("5"), testQuery));
|
||||||
|
specifications.add(new RatedRequest("berlin_query", createRelevant("1"), testQuery));
|
||||||
|
|
||||||
|
MeanReciprocalRank metric = new MeanReciprocalRank(1, 10);
|
||||||
|
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||||
|
task.addIndices(Collections.singletonList("test"));
|
||||||
|
|
||||||
|
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
|
builder.setRankEvalSpec(task);
|
||||||
|
|
||||||
|
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||||
|
// the expected reciprocal rank for the amsterdam_query is 1/5
|
||||||
|
// the expected reciprocal rank for the berlin_query is 1/1
|
||||||
|
// dividing by 2 to get the average
|
||||||
|
double expectedMRR = (1.0 / 1.0 + 1.0 / 5.0) / 2.0;
|
||||||
|
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
|
||||||
|
|
||||||
|
// test that a different window size k affects the result
|
||||||
|
metric = new MeanReciprocalRank(1, 3);
|
||||||
|
task = new RankEvalSpec(specifications, metric);
|
||||||
|
task.addIndices(Collections.singletonList("test"));
|
||||||
|
|
||||||
|
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
|
builder.setRankEvalSpec(task);
|
||||||
|
|
||||||
|
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||||
|
// limiting to top 3 results, the amsterdam_query has no relevant document in it
|
||||||
|
// the reciprocal rank for the berlin_query is 1/1
|
||||||
|
// dividing by 2 to get the average
|
||||||
|
expectedMRR = (1.0/ 1.0) / 2.0;
|
||||||
|
assertEquals(expectedMRR, response.getEvaluationResult(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -162,16 +247,13 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||||
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
|
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
|
||||||
task.addIndices(indices);
|
task.addIndices(indices);
|
||||||
|
|
||||||
try (Client client = client()) {
|
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest());
|
|
||||||
builder.setRankEvalSpec(task);
|
builder.setRankEvalSpec(task);
|
||||||
|
|
||||||
RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||||
assertEquals(1, response.getFailures().size());
|
assertEquals(1, response.getFailures().size());
|
||||||
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
|
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
|
||||||
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
|
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"", rootCauses[0].getCause().toString());
|
||||||
rootCauses[0].getCause().toString());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<RatedDocument> createRelevant(String... docs) {
|
private static List<RatedDocument> createRelevant(String... docs) {
|
||||||
|
|
Loading…
Reference in New Issue