Adds recall@k metric to rank eval API (#52889)

This change adds the recall@k metric and refactors precision@k to match
the new metric.

Recall@k is an important metric to use for learning to rank (LTR)
use-cases. Candidate generation or first ranking phase ranking functions
are often optimized for high recall, in order to generate as many
relevant candidates in the top-k as possible for a second phase of
ranking. Adding this metric allows tuning that base query for LTR.

See: https://github.com/elastic/elasticsearch/issues/51676
Backports: https://github.com/elastic/elasticsearch/pull/52577
This commit is contained in:
Josh Devins 2020-02-27 16:04:24 +01:00 committed by GitHub
parent 3c8b46a8c1
commit 68ba571f70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 748 additions and 132 deletions

View File

@ -28,6 +28,7 @@ import org.elasticsearch.index.rankeval.EvaluationMetric;
import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
import org.elasticsearch.index.rankeval.PrecisionAtK;
import org.elasticsearch.index.rankeval.RecallAtK;
import org.elasticsearch.index.rankeval.RankEvalRequest;
import org.elasticsearch.index.rankeval.RankEvalResponse;
import org.elasticsearch.index.rankeval.RankEvalSpec;
@ -130,9 +131,9 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
*/
public void testMetrics() throws IOException {
List<RatedRequest> specifications = createTestEvaluationSpec();
List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new,
() -> new ExpectedReciprocalRank(1));
double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095};
List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, RecallAtK::new,
MeanReciprocalRank::new, DiscountedCumulativeGain::new, () -> new ExpectedReciprocalRank(1));
double expectedScores[] = new double[] {0.4285714285714286, 1.0, 0.75, 1.6408962261063627, 0.4407738095238095};
int i = 0;
for (Supplier<EvaluationMetric> metricSupplier : metrics) {
RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get());

View File

@ -98,6 +98,7 @@ import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
import org.elasticsearch.index.rankeval.MetricDetail;
import org.elasticsearch.index.rankeval.PrecisionAtK;
import org.elasticsearch.index.rankeval.RecallAtK;
import org.elasticsearch.join.aggregations.ChildrenAggregationBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHits;
@ -696,7 +697,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(57, namedXContents.size());
assertEquals(59, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -710,13 +711,15 @@ public class RestHighLevelClientTests extends ESTestCase {
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class));
assertEquals(Integer.valueOf(5), categories.get(EvaluationMetric.class));
assertTrue(names.contains(PrecisionAtK.NAME));
assertTrue(names.contains(RecallAtK.NAME));
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
assertTrue(names.contains(MeanReciprocalRank.NAME));
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class));
assertEquals(Integer.valueOf(5), categories.get(MetricDetail.class));
assertTrue(names.contains(PrecisionAtK.NAME));
assertTrue(names.contains(RecallAtK.NAME));
assertTrue(names.contains(MeanReciprocalRank.NAME));
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
assertTrue(names.contains(ExpectedReciprocalRank.NAME));

View File

@ -203,20 +203,21 @@ will be used. The following metrics are supported:
[[k-precision]]
===== Precision at K (P@k)
This metric measures the number of relevant results in the top k search results.
It's a form of the well-known
https://en.wikipedia.org/wiki/Information_retrieval#Precision[Precision] metric
that only looks at the top k documents. It is the fraction of relevant documents
in those first k results. A precision at 10 (P@10) value of 0.6 then means six
out of the 10 top hits are relevant with respect to the user's information need.
This metric measures the proportion of relevant results in the top k search results.
It's a form of the well-known
https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision[Precision]
metric that only looks at the top k documents. It is the fraction of relevant
documents in those first k results. A precision at 10 (P@10) value of 0.6 then
means 6 out of the 10 top hits are relevant with respect to the user's
information need.
P@k works well as a simple evaluation metric that has the benefit of being easy
to understand and explain. Documents in the collection need to be rated as either
relevant or irrelevant with respect to the current query. P@k does not take
into account the position of the relevant documents within the top k results,
so a ranking of ten results that contains one relevant result in position 10 is
equally as good as a ranking of ten results that contains one relevant result
in position 1.
P@k works well as a simple evaluation metric that has the benefit of being easy
to understand and explain. Documents in the collection need to be rated as either
relevant or irrelevant with respect to the current query. P@k is a set-based
metric and does not take into account the position of the relevant documents
within the top k results, so a ranking of ten results that contains one
relevant result in position 10 is equally as good as a ranking of ten results
that contains one relevant result in position 1.
[source,console]
--------------------------------
@ -253,6 +254,58 @@ If set to 'true', unlabeled documents are ignored and neither count as relevant
|=======================================================================
[float]
[[k-recall]]
===== Recall at K (R@k)
This metric measures the total number of relevant results in the top k search
results. It's a form of the well-known
https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall[Recall]
metric. It is the fraction of relevant documents in those first k results
relative to all possible relevant results. A recall at 10 (R@10) value of 0.5 then
means 4 out of 8 relevant documents, with respect to the user's information
need, were retrieved in the 10 top hits.
R@k works well as a simple evaluation metric that has the benefit of being easy
to understand and explain. Documents in the collection need to be rated as either
relevant or irrelevant with respect to the current query. R@k is a set-based
metric and does not take into account the position of the relevant documents
within the top k results, so a ranking of ten results that contains one
relevant result in position 10 is equally as good as a ranking of ten results
that contains one relevant result in position 1.
[source,console]
--------------------------------
GET /twitter/_rank_eval
{
"requests": [
{
"id": "JFK query",
"request": { "query": { "match_all": {}}},
"ratings": []
}],
"metric": {
"recall": {
"k" : 20,
"relevant_rating_threshold": 1
}
}
}
--------------------------------
// TEST[setup:twitter]
The `recall` metric takes the following optional parameters
[cols="<,<",options="header",]
|=======================================================================
|Parameter |Description
|`k` |sets the maximum number of documents retrieved per query. This value will act in place of the usual `size` parameter
in the query. Defaults to 10.
|`relevant_rating_threshold` |sets the rating threshold above which documents are considered to be
"relevant". Defaults to `1`.
|=======================================================================
[float]
===== Mean reciprocal rank

View File

@ -26,7 +26,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
/**
* Details about a specific {@link EvaluationMetric} that should be included in the resonse.
* Details about a specific {@link EvaluationMetric} that should be included in the response.
*/
public interface MetricDetail extends ToXContentObject, NamedWriteable {

View File

@ -40,10 +40,10 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
/**
* Metric implementing Precision@K
* (https://en.wikipedia.org/wiki/Information_retrieval#Precision_at_K).<br>
* (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision).<br>
* By default documents with a rating equal or bigger than 1 are considered to
* be "relevant" for this calculation. This value can be changes using the
* relevant_rating_threshold` parameter.<br>
* be "relevant" for this calculation. This value can be changed using the
* `relevant_rating_threshold` parameter.<br>
* The `ignore_unlabeled` parameter (default to false) controls if unrated
* documents should be ignored.
* The `k` parameter (defaults to 10) controls the search window size.
@ -52,19 +52,21 @@ public class PrecisionAtK implements EvaluationMetric {
public static final String NAME = "precision";
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1;
private static final boolean DEFAULT_IGNORE_UNLABELED = false;
private static final int DEFAULT_K = 10;
private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
private static final ParseField K_FIELD = new ParseField("k");
private static final int DEFAULT_K = 10;
private final int relevantRatingThreshold;
private final boolean ignoreUnlabeled;
private final int relevantRatingThreshhold;
private final int k;
/**
* Metric implementing Precision@K.
* @param threshold
* @param relevantRatingThreshold
* ratings equal or above this value will be considered relevant.
* @param ignoreUnlabeled
* Controls how unlabeled documents in the search hits are treated.
@ -74,53 +76,67 @@ public class PrecisionAtK implements EvaluationMetric {
* @param k
* controls the window size for the search results the metric takes into account
*/
public PrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) {
if (threshold < 0) {
public PrecisionAtK(int relevantRatingThreshold, boolean ignoreUnlabeled, int k) {
if (relevantRatingThreshold < 0) {
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
}
if (k <= 0) {
throw new IllegalArgumentException("Window size k must be positive.");
}
this.relevantRatingThreshhold = threshold;
this.relevantRatingThreshold = relevantRatingThreshold;
this.ignoreUnlabeled = ignoreUnlabeled;
this.k = k;
}
public PrecisionAtK() {
this(1, false, DEFAULT_K);
public PrecisionAtK(boolean ignoreUnlabeled) {
this(DEFAULT_RELEVANT_RATING_THRESHOLD, ignoreUnlabeled, DEFAULT_K);
}
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
args -> {
Integer threshHold = (Integer) args[0];
Boolean ignoreUnlabeled = (Boolean) args[1];
Integer k = (Integer) args[2];
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
ignoreUnlabeled == null ? false : ignoreUnlabeled,
k == null ? DEFAULT_K : k);
});
public PrecisionAtK() {
this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_IGNORE_UNLABELED, DEFAULT_K);
}
PrecisionAtK(StreamInput in) throws IOException {
this(in.readVInt(), in.readBoolean(), in.readVInt());
}
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
Integer relevantRatingThreshold = (Integer) args[0];
Boolean ignoreUnlabeled = (Boolean) args[1];
Integer k = (Integer) args[2];
return new PrecisionAtK(
relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold,
ignoreUnlabeled == null ? DEFAULT_IGNORE_UNLABELED : ignoreUnlabeled,
k == null ? DEFAULT_K : k);
});
static {
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
PrecisionAtK(StreamInput in) throws IOException {
relevantRatingThreshhold = in.readVInt();
ignoreUnlabeled = in.readBoolean();
k = in.readVInt();
}
int getK() {
return this.k;
public static PrecisionAtK fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRatingThreshhold);
out.writeBoolean(ignoreUnlabeled);
out.writeVInt(k);
out.writeVInt(getRelevantRatingThreshold());
out.writeBoolean(getIgnoreUnlabeled());
out.writeVInt(getK());
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), getRelevantRatingThreshold());
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), getIgnoreUnlabeled());
builder.field(K_FIELD.getPreferredName(), getK());
builder.endObject();
builder.endObject();
return builder;
}
@Override
@ -133,7 +149,7 @@ public class PrecisionAtK implements EvaluationMetric {
* "relevant" for this metric. Defaults to 1.
*/
public int getRelevantRatingThreshold() {
return relevantRatingThreshhold;
return relevantRatingThreshold;
}
/**
@ -143,61 +159,66 @@ public class PrecisionAtK implements EvaluationMetric {
return ignoreUnlabeled;
}
@Override
public OptionalInt forcedSearchSize() {
return OptionalInt.of(k);
public int getK() {
return k;
}
public static PrecisionAtK fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
@Override
public OptionalInt forcedSearchSize() {
return OptionalInt.of(getK());
}
/**
* Compute precisionAtN based on provided relevant document IDs.
* Binarizes a rating based on the relevant rating threshold.
*/
private boolean isRelevant(int rating) {
return rating >= getRelevantRatingThreshold();
}
/**
* Should we count unlabeled documents? This is the inverse of {@link #getIgnoreUnlabeled()}.
*/
private boolean shouldCountUnlabeled() {
return !getIgnoreUnlabeled();
}
/**
* Compute precision at k based on provided relevant document IDs.
*
* @return precision at n for above {@link SearchResult} list.
* @return precision at k for above {@link SearchResult} list.
**/
@Override
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
List<RatedDocument> ratedDocs) {
int truePositives = 0;
int falsePositives = 0;
List<RatedDocument> ratedDocs) {
List<RatedSearchHit> ratedSearchHits = joinHitsWithRatings(hits, ratedDocs);
int relevantRetrieved = 0;
int retrieved = 0;
for (RatedSearchHit hit : ratedSearchHits) {
OptionalInt rating = hit.getRating();
if (rating.isPresent()) {
if (rating.getAsInt() >= this.relevantRatingThreshhold) {
truePositives++;
} else {
falsePositives++;
retrieved++;
if (isRelevant(rating.getAsInt())) {
relevantRetrieved++;
}
} else if (ignoreUnlabeled == false) {
falsePositives++;
} else if (shouldCountUnlabeled()) {
retrieved++;
}
}
double precision = 0.0;
if (truePositives + falsePositives > 0) {
precision = (double) truePositives / (truePositives + falsePositives);
if (retrieved > 0) {
precision = (double) relevantRetrieved / retrieved;
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision);
evalQueryQuality.setMetricDetails(
new PrecisionAtK.Detail(truePositives, truePositives + falsePositives));
evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(relevantRetrieved, retrieved));
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
return evalQueryQuality;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
@ -207,20 +228,21 @@ public class PrecisionAtK implements EvaluationMetric {
return false;
}
PrecisionAtK other = (PrecisionAtK) obj;
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
&& Objects.equals(k, other.k)
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
return Objects.equals(relevantRatingThreshold, other.relevantRatingThreshold)
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled)
&& Objects.equals(k, other.k);
}
@Override
public final int hashCode() {
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k);
return Objects.hash(relevantRatingThreshold, ignoreUnlabeled, k);
}
public static final class Detail implements MetricDetail {
private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");
private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");
private int relevantRetrieved;
private int retrieved;
@ -230,21 +252,11 @@ public class PrecisionAtK implements EvaluationMetric {
}
Detail(StreamInput in) throws IOException {
this.relevantRetrieved = in.readVInt();
this.retrieved = in.readVInt();
this(in.readVInt(), in.readVInt());
}
@Override
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
throws IOException {
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved);
return builder;
}
private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Detail((Integer) args[0], (Integer) args[1]);
});
private static final ConstructingObjectParser<Detail, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new Detail((Integer) args[0], (Integer) args[1]));
static {
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD);
@ -257,8 +269,16 @@ public class PrecisionAtK implements EvaluationMetric {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRetrieved);
out.writeVInt(retrieved);
out.writeVLong(relevantRetrieved);
out.writeVLong(retrieved);
}
@Override
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
throws IOException {
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved);
return builder;
}
@Override

View File

@ -31,8 +31,10 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.add(
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME),
PrecisionAtK::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallAtK.NAME),
RecallAtK::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
@ -42,6 +44,8 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME),
PrecisionAtK.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(RecallAtK.NAME),
RecallAtK.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),

View File

@ -58,12 +58,14 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, RecallAtK.NAME, RecallAtK.Detail::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));

View File

@ -0,0 +1,280 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import javax.naming.directory.SearchResult;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
/**
* Metric implementing Recall@K
* (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall).<br>
* By default documents with a rating equal or bigger than 1 are considered to
* be "relevant" for this calculation. This value can be changed using the
* `relevant_rating_threshold` parameter.<br>
* The `k` parameter (defaults to 10) controls the search window size.
*/
public class RecallAtK implements EvaluationMetric {
public static final String NAME = "recall";
private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1;
private static final int DEFAULT_K = 10;
private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField K_FIELD = new ParseField("k");
private final int relevantRatingThreshold;
private final int k;
/**
* Metric implementing Recall@K.
* @param relevantRatingThreshold
* ratings equal or above this value will be considered relevant.
* @param k
* controls the window size for the search results the metric takes into account
*/
public RecallAtK(int relevantRatingThreshold, int k) {
if (relevantRatingThreshold < 0) {
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
}
if (k <= 0) {
throw new IllegalArgumentException("Window size k must be positive.");
}
this.relevantRatingThreshold = relevantRatingThreshold;
this.k = k;
}
public RecallAtK() {
this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_K);
}
RecallAtK(StreamInput in) throws IOException {
this(in.readVInt(), in.readVInt());
}
private static final ConstructingObjectParser<RecallAtK, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
Integer relevantRatingThreshold = (Integer) args[0];
Integer k = (Integer) args[1];
return new RecallAtK(
relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold,
k == null ? DEFAULT_K : k);
});
static {
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
public static RecallAtK fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(getRelevantRatingThreshold());
out.writeVInt(getK());
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), getRelevantRatingThreshold());
builder.field(K_FIELD.getPreferredName(), getK());
builder.endObject();
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
/**
* Return the rating threshold above which ratings are considered to be
* "relevant" for this metric. Defaults to 1.
*/
public int getRelevantRatingThreshold() {
return relevantRatingThreshold;
}
public int getK() {
return k;
}
@Override
public OptionalInt forcedSearchSize() {
return OptionalInt.of(getK());
}
/**
* Binarizes a rating based on the relevant rating threshold.
*/
private boolean isRelevant(int rating) {
return rating >= getRelevantRatingThreshold();
}
/**
* Compute recall at k based on provided relevant document IDs.
*
* @return recall at k for above {@link SearchResult} list.
**/
@Override
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
List<RatedDocument> ratedDocs) {
List<RatedSearchHit> ratedSearchHits = joinHitsWithRatings(hits, ratedDocs);
int relevantRetrieved = 0;
for (RatedSearchHit hit : ratedSearchHits) {
OptionalInt rating = hit.getRating();
if (rating.isPresent() && isRelevant(rating.getAsInt())) {
relevantRetrieved++;
}
}
int relevant = 0;
for (RatedDocument rd : ratedDocs) {
if(isRelevant(rd.getRating())) {
relevant++;
}
}
double recall = 0.0;
if (relevant > 0) {
recall = (double) relevantRetrieved / relevant;
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, recall);
evalQueryQuality.setMetricDetails(new RecallAtK.Detail(relevantRetrieved, relevant));
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
return evalQueryQuality;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RecallAtK other = (RecallAtK) obj;
return Objects.equals(relevantRatingThreshold, other.relevantRatingThreshold)
&& Objects.equals(k, other.k);
}
@Override
public final int hashCode() {
return Objects.hash(relevantRatingThreshold, k);
}
public static final class Detail implements MetricDetail {
private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
private static final ParseField RELEVANT_DOCS_FIELD = new ParseField("relevant_docs");
private long relevantRetrieved;
private long relevant;
Detail(long relevantRetrieved, long relevant) {
this.relevantRetrieved = relevantRetrieved;
this.relevant = relevant;
}
Detail(StreamInput in) throws IOException {
this.relevantRetrieved = in.readVLong();
this.relevant = in.readVLong();
}
private static final ConstructingObjectParser<Detail, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new Detail((Integer) args[0], (Integer) args[1]));
static {
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD);
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_FIELD);
}
public static Detail fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(relevantRetrieved);
out.writeVLong(relevant);
}
@Override
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
throws IOException {
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
builder.field(RELEVANT_DOCS_FIELD.getPreferredName(), relevant);
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
public long getRelevantRetrieved() {
return relevantRetrieved;
}
public long getRelevant() {
return relevant;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RecallAtK.Detail other = (RecallAtK.Detail) obj;
return Objects.equals(relevantRetrieved, other.relevantRetrieved)
&& Objects.equals(relevant, other.relevant);
}
@Override
public int hashCode() {
return Objects.hash(relevantRetrieved, relevant);
}
}
}

View File

@ -48,25 +48,25 @@ import static org.hamcrest.CoreMatchers.containsString;
public class PrecisionAtKTests extends ESTestCase {
private static final int IRRELEVANT_RATING_0 = 0;
private static final int RELEVANT_RATING_1 = 1;
private static final int IRRELEVANT_RATING = 0;
private static final int RELEVANT_RATING = 1;
public void testPrecisionAtFiveCalculation() {
public void testCalculation() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals(1, evaluated.metricScore(), 0.00001);
assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testPrecisionAtFiveIgnoreOneResult() {
public void testIgnoreOneResult() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "2", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "3", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING_0));
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
rated.add(createRatedDoc("test", "2", RELEVANT_RATING));
rated.add(createRatedDoc("test", "3", RELEVANT_RATING));
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING));
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 4 / 5, evaluated.metricScore(), 0.00001);
assertEquals(4, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -75,10 +75,10 @@ public class PrecisionAtKTests extends ESTestCase {
/**
* test that the relevant rating threshold can be set to something larger than
* 1. e.g. we set it to 2 here and expect dics 0-2 to be not relevant, doc 3 and
* 4 to be relevant
* 1. e.g. we set it to 2 here and expect docs 0-1 to be not relevant, docs 2-4
* to be relevant
*/
public void testPrecisionAtFiveRelevanceThreshold() {
public void testRelevanceThreshold() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", 0));
rated.add(createRatedDoc("test", "1", 1));
@ -94,13 +94,15 @@ public class PrecisionAtKTests extends ESTestCase {
public void testPrecisionAtFiveCorrectIndex() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING_1));
rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING_0));
rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING));
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING));
// the following search hits contain only the last three documents
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated);
List<RatedDocument> ratedSubList = rated.subList(2, 5);
PrecisionAtK precisionAtK = new PrecisionAtK(1, false, 5);
EvalQueryQuality evaluated = (precisionAtK).evaluate("id", toSearchHits(ratedSubList, "test"), rated);
assertEquals((double) 2 / 3, evaluated.metricScore(), 0.00001);
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
@ -108,8 +110,8 @@ public class PrecisionAtKTests extends ESTestCase {
public void testIgnoreUnlabeled() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1));
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
// add an unlabeled search hit
SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test"), 3);
searchHits[2] = new SearchHit(2, "2", new Text(MapperService.SINGLE_MAPPING_NAME), Collections.emptyMap());
@ -121,7 +123,7 @@ public class PrecisionAtKTests extends ESTestCase {
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
PrecisionAtK prec = new PrecisionAtK(true);
evaluated = prec.evaluate("id", searchHits, rated);
assertEquals((double) 2 / 2, evaluated.metricScore(), 0.00001);
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -140,7 +142,7 @@ public class PrecisionAtKTests extends ESTestCase {
assertEquals(5, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
PrecisionAtK prec = new PrecisionAtK(true);
evaluated = prec.evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());

View File

@ -73,6 +73,7 @@ public class RankEvalSpecTests extends ESTestCase {
static RankEvalSpec createTestItem() {
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
() -> PrecisionAtKTests.createTestItem(),
() -> RecallAtKTests.createTestItem(),
() -> MeanReciprocalRankTests.createTestItem(),
() -> DiscountedCumulativeGainTests.createTestItem()));
@ -149,6 +150,7 @@ public class RankEvalSpecTests extends ESTestCase {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));

View File

@ -0,0 +1,249 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.elasticsearch.test.XContentTestUtils.insertRandomFields;
import static org.hamcrest.CoreMatchers.containsString;
public class RecallAtKTests extends ESTestCase {
private static final int IRRELEVANT_RATING = 0;
private static final int RELEVANT_RATING = 1;
public void testCalculation() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals(1, evaluated.metricScore(), 0.00001);
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
public void testIgnoreOneResult() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
rated.add(createRatedDoc("test", "2", RELEVANT_RATING));
rated.add(createRatedDoc("test", "3", RELEVANT_RATING));
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING));
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 4 / 4, evaluated.metricScore(), 0.00001);
assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
/**
* Test that the relevant rating threshold can be set to something larger than
* 1. e.g. we set it to 2 here and expect docs 0-1 to be not relevant, docs 2-4
* to be relevant, and only 0-3 are hits.
*/
public void testRelevanceThreshold() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", 0)); // not relevant, hit
rated.add(createRatedDoc("test", "1", 1)); // not relevant, hit
rated.add(createRatedDoc("test", "2", 2)); // relevant, hit
rated.add(createRatedDoc("test", "3", 3)); // relevant
rated.add(createRatedDoc("test", "4", 4)); // relevant
RecallAtK recallAtN = new RecallAtK(2, 5);
EvalQueryQuality evaluated = recallAtN.evaluate("id", toSearchHits(rated.subList(0,3), "test"), rated);
assertEquals((double) 1 / 3, evaluated.metricScore(), 0.00001);
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
public void testCorrectIndex() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING));
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING));
// the following search hits contain only the last three documents
List<RatedDocument> ratedSubList = rated.subList(2, 5);
EvalQueryQuality evaluated = (new RecallAtK(1, 5)).evaluate("id", toSearchHits(ratedSubList, "test"), rated);
assertEquals((double) 2 / 4, evaluated.metricScore(), 0.00001);
assertEquals(2, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
public void testNoRatedDocs() throws Exception {
int k = 5;
SearchHit[] hits = new SearchHit[k];
for (int i = 0; i < k; i++) {
hits[i] = new SearchHit(i, i + "", new Text(""), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE));
}
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
public void testNoResults() throws Exception {
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", new SearchHit[0], Collections.emptyList());
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
public void testNoResultsWithRatedDocs() throws Exception {
List<RatedDocument> rated = new ArrayList<>();
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", new SearchHit[0], rated);
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
}
public void testParseFromXContent() throws IOException {
String xContent = " {\n" + " \"relevant_rating_threshold\" : 2" + "}";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
RecallAtK recallAtK = RecallAtK.fromXContent(parser);
assertEquals(2, recallAtK.getRelevantRatingThreshold());
}
}
public void testCombine() {
RecallAtK metric = new RecallAtK();
List<EvalQueryQuality> partialResults = new ArrayList<>(3);
partialResults.add(new EvalQueryQuality("a", 0.1));
partialResults.add(new EvalQueryQuality("b", 0.2));
partialResults.add(new EvalQueryQuality("c", 0.6));
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
}
public void testInvalidRelevantThreshold() {
expectThrows(IllegalArgumentException.class, () -> new RecallAtK(-1, 10));
}
public void testInvalidK() {
expectThrows(IllegalArgumentException.class, () -> new RecallAtK(1, -10));
}
public static RecallAtK createTestItem() {
return new RecallAtK(randomIntBetween(0, 10), randomIntBetween(1, 50));
}
public void testXContentRoundtrip() throws IOException {
RecallAtK testItem = createTestItem();
XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
try (XContentParser itemParser = createParser(shuffled)) {
itemParser.nextToken();
itemParser.nextToken();
RecallAtK parsedItem = RecallAtK.fromXContent(itemParser);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}
public void testXContentParsingIsNotLenient() throws IOException {
RecallAtK testItem = createTestItem();
XContentType xContentType = randomFrom(XContentType.values());
BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean());
BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random());
try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) {
parser.nextToken();
parser.nextToken();
XContentParseException exception = expectThrows(XContentParseException.class, () -> RecallAtK.fromXContent(parser));
assertThat(exception.getMessage(), containsString("[recall] unknown field"));
}
}
public void testSerialization() throws IOException {
RecallAtK original = createTestItem();
RecallAtK deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
RecallAtK::new);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
checkEqualsAndHashCode(createTestItem(), RecallAtKTests::copy, RecallAtKTests::mutate);
}
private static RecallAtK copy(RecallAtK original) {
return new RecallAtK(original.getRelevantRatingThreshold(), original.forcedSearchSize().getAsInt());
}
private static RecallAtK mutate(RecallAtK original) {
RecallAtK recallAtK;
switch (randomIntBetween(0, 1)) {
case 0:
recallAtK = new RecallAtK(
randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
original.forcedSearchSize().getAsInt());
break;
case 1:
recallAtK = new RecallAtK(
original.getRelevantRatingThreshold(),
original.forcedSearchSize().getAsInt() + 1);
break;
default:
throw new IllegalStateException("The test should only allow two parameters mutated");
}
return recallAtK;
}
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index) {
SearchHit[] hits = new SearchHit[rated.size()];
for (int i = 0; i < rated.size(); i++) {
hits[i] = new SearchHit(i, i + "", new Text(""), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new ShardId(index, "uuid", 0), null, OriginalIndices.NONE));
}
return hits;
}
private static RatedDocument createRatedDoc(String index, String id, int rating) {
return new RatedDocument(index, id, rating);
}
}