mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-24 17:09:48 +00:00
Adds recall@k metric to rank eval API (#52889)
This change adds the recall@k metric and refactors precision@k to match the new metric. Recall@k is an important metric to use for learning to rank (LTR) use-cases. Candidate generation or first ranking phase ranking functions are often optimized for high recall, in order to generate as many relevant candidates in the top-k as possible for a second phase of ranking. Adding this metric allows tuning that base query for LTR. See: https://github.com/elastic/elasticsearch/issues/51676 Backports: https://github.com/elastic/elasticsearch/pull/52577
This commit is contained in:
parent
3c8b46a8c1
commit
68ba571f70
@ -28,6 +28,7 @@ import org.elasticsearch.index.rankeval.EvaluationMetric;
|
||||
import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtK;
|
||||
import org.elasticsearch.index.rankeval.RecallAtK;
|
||||
import org.elasticsearch.index.rankeval.RankEvalRequest;
|
||||
import org.elasticsearch.index.rankeval.RankEvalResponse;
|
||||
import org.elasticsearch.index.rankeval.RankEvalSpec;
|
||||
@ -130,9 +131,9 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
|
||||
*/
|
||||
public void testMetrics() throws IOException {
|
||||
List<RatedRequest> specifications = createTestEvaluationSpec();
|
||||
List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new,
|
||||
() -> new ExpectedReciprocalRank(1));
|
||||
double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095};
|
||||
List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, RecallAtK::new,
|
||||
MeanReciprocalRank::new, DiscountedCumulativeGain::new, () -> new ExpectedReciprocalRank(1));
|
||||
double expectedScores[] = new double[] {0.4285714285714286, 1.0, 0.75, 1.6408962261063627, 0.4407738095238095};
|
||||
int i = 0;
|
||||
for (Supplier<EvaluationMetric> metricSupplier : metrics) {
|
||||
RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get());
|
||||
|
@ -98,6 +98,7 @@ import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.MetricDetail;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtK;
|
||||
import org.elasticsearch.index.rankeval.RecallAtK;
|
||||
import org.elasticsearch.join.aggregations.ChildrenAggregationBuilder;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
@ -696,7 +697,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(57, namedXContents.size());
|
||||
assertEquals(59, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
@ -710,13 +711,15 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
|
||||
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
|
||||
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
|
||||
assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(5), categories.get(EvaluationMetric.class));
|
||||
assertTrue(names.contains(PrecisionAtK.NAME));
|
||||
assertTrue(names.contains(RecallAtK.NAME));
|
||||
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
|
||||
assertTrue(names.contains(MeanReciprocalRank.NAME));
|
||||
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
|
||||
assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class));
|
||||
assertEquals(Integer.valueOf(5), categories.get(MetricDetail.class));
|
||||
assertTrue(names.contains(PrecisionAtK.NAME));
|
||||
assertTrue(names.contains(RecallAtK.NAME));
|
||||
assertTrue(names.contains(MeanReciprocalRank.NAME));
|
||||
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
|
||||
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
|
||||
|
@ -203,20 +203,21 @@ will be used. The following metrics are supported:
|
||||
[[k-precision]]
|
||||
===== Precision at K (P@k)
|
||||
|
||||
This metric measures the number of relevant results in the top k search results.
|
||||
It's a form of the well-known
|
||||
https://en.wikipedia.org/wiki/Information_retrieval#Precision[Precision] metric
|
||||
that only looks at the top k documents. It is the fraction of relevant documents
|
||||
in those first k results. A precision at 10 (P@10) value of 0.6 then means six
|
||||
out of the 10 top hits are relevant with respect to the user's information need.
|
||||
This metric measures the proportion of relevant results in the top k search results.
|
||||
It's a form of the well-known
|
||||
https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision[Precision]
|
||||
metric that only looks at the top k documents. It is the fraction of relevant
|
||||
documents in those first k results. A precision at 10 (P@10) value of 0.6 then
|
||||
means 6 out of the 10 top hits are relevant with respect to the user's
|
||||
information need.
|
||||
|
||||
P@k works well as a simple evaluation metric that has the benefit of being easy
|
||||
to understand and explain. Documents in the collection need to be rated as either
|
||||
relevant or irrelevant with respect to the current query. P@k does not take
|
||||
into account the position of the relevant documents within the top k results,
|
||||
so a ranking of ten results that contains one relevant result in position 10 is
|
||||
equally as good as a ranking of ten results that contains one relevant result
|
||||
in position 1.
|
||||
P@k works well as a simple evaluation metric that has the benefit of being easy
|
||||
to understand and explain. Documents in the collection need to be rated as either
|
||||
relevant or irrelevant with respect to the current query. P@k is a set-based
|
||||
metric and does not take into account the position of the relevant documents
|
||||
within the top k results, so a ranking of ten results that contains one
|
||||
relevant result in position 10 is equally as good as a ranking of ten results
|
||||
that contains one relevant result in position 1.
|
||||
|
||||
[source,console]
|
||||
--------------------------------
|
||||
@ -253,6 +254,58 @@ If set to 'true', unlabeled documents are ignored and neither count as relevant
|
||||
|=======================================================================
|
||||
|
||||
|
||||
[float]
|
||||
[[k-recall]]
|
||||
===== Recall at K (R@k)
|
||||
|
||||
This metric measures the proportion of all relevant results that are retrieved in the top k search
|
||||
results. It's a form of the well-known
|
||||
https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall[Recall]
|
||||
metric. It is the fraction of relevant documents in those first k results
|
||||
relative to all possible relevant results. A recall at 10 (R@10) value of 0.5 then
|
||||
means 4 out of 8 relevant documents, with respect to the user's information
|
||||
need, were retrieved in the 10 top hits.
|
||||
|
||||
R@k works well as a simple evaluation metric that has the benefit of being easy
|
||||
to understand and explain. Documents in the collection need to be rated as either
|
||||
relevant or irrelevant with respect to the current query. R@k is a set-based
|
||||
metric and does not take into account the position of the relevant documents
|
||||
within the top k results, so a ranking of ten results that contains one
|
||||
relevant result in position 10 is equally as good as a ranking of ten results
|
||||
that contains one relevant result in position 1.
|
||||
|
||||
[source,console]
|
||||
--------------------------------
|
||||
GET /twitter/_rank_eval
|
||||
{
|
||||
"requests": [
|
||||
{
|
||||
"id": "JFK query",
|
||||
"request": { "query": { "match_all": {}}},
|
||||
"ratings": []
|
||||
}],
|
||||
"metric": {
|
||||
"recall": {
|
||||
"k" : 20,
|
||||
"relevant_rating_threshold": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
--------------------------------
|
||||
// TEST[setup:twitter]
|
||||
|
||||
The `recall` metric takes the following optional parameters:
|
||||
|
||||
[cols="<,<",options="header",]
|
||||
|=======================================================================
|
||||
|Parameter |Description
|
||||
|`k` |sets the maximum number of documents retrieved per query. This value will act in place of the usual `size` parameter
|
||||
in the query. Defaults to 10.
|
||||
|`relevant_rating_threshold` |sets the rating threshold at or above which documents are considered to be
|
||||
"relevant". Defaults to `1`.
|
||||
|=======================================================================
|
||||
|
||||
|
||||
[float]
|
||||
===== Mean reciprocal rank
|
||||
|
||||
|
@ -26,7 +26,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Details about a specific {@link EvaluationMetric} that should be included in the resonse.
|
||||
* Details about a specific {@link EvaluationMetric} that should be included in the response.
|
||||
*/
|
||||
public interface MetricDetail extends ToXContentObject, NamedWriteable {
|
||||
|
||||
|
@ -40,10 +40,10 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
|
||||
|
||||
/**
|
||||
* Metric implementing Precision@K
|
||||
* (https://en.wikipedia.org/wiki/Information_retrieval#Precision_at_K).<br>
|
||||
* (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision).<br>
|
||||
* By default documents with a rating equal or bigger than 1 are considered to
|
||||
* be "relevant" for this calculation. This value can be changes using the
|
||||
* relevant_rating_threshold` parameter.<br>
|
||||
* be "relevant" for this calculation. This value can be changed using the
|
||||
* `relevant_rating_threshold` parameter.<br>
|
||||
* The `ignore_unlabeled` parameter (defaults to false) controls if unrated
|
||||
* documents should be ignored.
|
||||
* The `k` parameter (defaults to 10) controls the search window size.
|
||||
@ -52,19 +52,21 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "precision";
|
||||
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1;
|
||||
private static final boolean DEFAULT_IGNORE_UNLABELED = false;
|
||||
private static final int DEFAULT_K = 10;
|
||||
|
||||
private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
|
||||
private static final ParseField K_FIELD = new ParseField("k");
|
||||
|
||||
private static final int DEFAULT_K = 10;
|
||||
|
||||
private final int relevantRatingThreshold;
|
||||
private final boolean ignoreUnlabeled;
|
||||
private final int relevantRatingThreshhold;
|
||||
private final int k;
|
||||
|
||||
/**
|
||||
* Metric implementing Precision@K.
|
||||
* @param threshold
|
||||
* @param relevantRatingThreshold
|
||||
* ratings equal or above this value will be considered relevant.
|
||||
* @param ignoreUnlabeled
|
||||
* Controls how unlabeled documents in the search hits are treated.
|
||||
@ -74,53 +76,67 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
* @param k
|
||||
* controls the window size for the search results the metric takes into account
|
||||
*/
|
||||
public PrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) {
|
||||
if (threshold < 0) {
|
||||
public PrecisionAtK(int relevantRatingThreshold, boolean ignoreUnlabeled, int k) {
|
||||
if (relevantRatingThreshold < 0) {
|
||||
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
if (k <= 0) {
|
||||
throw new IllegalArgumentException("Window size k must be positive.");
|
||||
}
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
this.relevantRatingThreshold = relevantRatingThreshold;
|
||||
this.ignoreUnlabeled = ignoreUnlabeled;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
public PrecisionAtK() {
|
||||
this(1, false, DEFAULT_K);
|
||||
public PrecisionAtK(boolean ignoreUnlabeled) {
|
||||
this(DEFAULT_RELEVANT_RATING_THRESHOLD, ignoreUnlabeled, DEFAULT_K);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
args -> {
|
||||
Integer threshHold = (Integer) args[0];
|
||||
Boolean ignoreUnlabeled = (Boolean) args[1];
|
||||
Integer k = (Integer) args[2];
|
||||
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
|
||||
ignoreUnlabeled == null ? false : ignoreUnlabeled,
|
||||
k == null ? DEFAULT_K : k);
|
||||
});
|
||||
public PrecisionAtK() {
|
||||
this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_IGNORE_UNLABELED, DEFAULT_K);
|
||||
}
|
||||
|
||||
PrecisionAtK(StreamInput in) throws IOException {
|
||||
this(in.readVInt(), in.readBoolean(), in.readVInt());
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
|
||||
Integer relevantRatingThreshold = (Integer) args[0];
|
||||
Boolean ignoreUnlabeled = (Boolean) args[1];
|
||||
Integer k = (Integer) args[2];
|
||||
return new PrecisionAtK(
|
||||
relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold,
|
||||
ignoreUnlabeled == null ? DEFAULT_IGNORE_UNLABELED : ignoreUnlabeled,
|
||||
k == null ? DEFAULT_K : k);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||
}
|
||||
|
||||
PrecisionAtK(StreamInput in) throws IOException {
|
||||
relevantRatingThreshhold = in.readVInt();
|
||||
ignoreUnlabeled = in.readBoolean();
|
||||
k = in.readVInt();
|
||||
}
|
||||
|
||||
int getK() {
|
||||
return this.k;
|
||||
public static PrecisionAtK fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(relevantRatingThreshhold);
|
||||
out.writeBoolean(ignoreUnlabeled);
|
||||
out.writeVInt(k);
|
||||
out.writeVInt(getRelevantRatingThreshold());
|
||||
out.writeBoolean(getIgnoreUnlabeled());
|
||||
out.writeVInt(getK());
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.startObject(NAME);
|
||||
builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), getRelevantRatingThreshold());
|
||||
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), getIgnoreUnlabeled());
|
||||
builder.field(K_FIELD.getPreferredName(), getK());
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -133,7 +149,7 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
* "relevant" for this metric. Defaults to 1.
|
||||
*/
|
||||
public int getRelevantRatingThreshold() {
|
||||
return relevantRatingThreshhold;
|
||||
return relevantRatingThreshold;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -143,61 +159,66 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
return ignoreUnlabeled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OptionalInt forcedSearchSize() {
|
||||
return OptionalInt.of(k);
|
||||
public int getK() {
|
||||
return k;
|
||||
}
|
||||
|
||||
public static PrecisionAtK fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
@Override
|
||||
public OptionalInt forcedSearchSize() {
|
||||
return OptionalInt.of(getK());
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute precisionAtN based on provided relevant document IDs.
|
||||
* Binarizes a rating based on the relevant rating threshold.
|
||||
*/
|
||||
private boolean isRelevant(int rating) {
|
||||
return rating >= getRelevantRatingThreshold();
|
||||
}
|
||||
|
||||
/**
|
||||
* Should we count unlabeled documents? This is the inverse of {@link #getIgnoreUnlabeled()}.
|
||||
*/
|
||||
private boolean shouldCountUnlabeled() {
|
||||
return !getIgnoreUnlabeled();
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute precision at k based on provided relevant document IDs.
|
||||
*
|
||||
* @return precision at n for above {@link SearchResult} list.
|
||||
* @return precision at k for above {@link SearchResult} list.
|
||||
**/
|
||||
@Override
|
||||
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
|
||||
List<RatedDocument> ratedDocs) {
|
||||
int truePositives = 0;
|
||||
int falsePositives = 0;
|
||||
List<RatedDocument> ratedDocs) {
|
||||
|
||||
List<RatedSearchHit> ratedSearchHits = joinHitsWithRatings(hits, ratedDocs);
|
||||
|
||||
int relevantRetrieved = 0;
|
||||
int retrieved = 0;
|
||||
|
||||
for (RatedSearchHit hit : ratedSearchHits) {
|
||||
OptionalInt rating = hit.getRating();
|
||||
if (rating.isPresent()) {
|
||||
if (rating.getAsInt() >= this.relevantRatingThreshhold) {
|
||||
truePositives++;
|
||||
} else {
|
||||
falsePositives++;
|
||||
retrieved++;
|
||||
if (isRelevant(rating.getAsInt())) {
|
||||
relevantRetrieved++;
|
||||
}
|
||||
} else if (ignoreUnlabeled == false) {
|
||||
falsePositives++;
|
||||
} else if (shouldCountUnlabeled()) {
|
||||
retrieved++;
|
||||
}
|
||||
}
|
||||
|
||||
double precision = 0.0;
|
||||
if (truePositives + falsePositives > 0) {
|
||||
precision = (double) truePositives / (truePositives + falsePositives);
|
||||
if (retrieved > 0) {
|
||||
precision = (double) relevantRetrieved / retrieved;
|
||||
}
|
||||
|
||||
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision);
|
||||
evalQueryQuality.setMetricDetails(
|
||||
new PrecisionAtK.Detail(truePositives, truePositives + falsePositives));
|
||||
evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(relevantRetrieved, retrieved));
|
||||
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
|
||||
return evalQueryQuality;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.startObject(NAME);
|
||||
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
|
||||
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
|
||||
builder.field(K_FIELD.getPreferredName(), this.k);
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
@ -207,20 +228,21 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
return false;
|
||||
}
|
||||
PrecisionAtK other = (PrecisionAtK) obj;
|
||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
|
||||
&& Objects.equals(k, other.k)
|
||||
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
|
||||
return Objects.equals(relevantRatingThreshold, other.relevantRatingThreshold)
|
||||
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled)
|
||||
&& Objects.equals(k, other.k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k);
|
||||
return Objects.hash(relevantRatingThreshold, ignoreUnlabeled, k);
|
||||
}
|
||||
|
||||
public static final class Detail implements MetricDetail {
|
||||
|
||||
private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");
|
||||
private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
|
||||
private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");
|
||||
|
||||
private int relevantRetrieved;
|
||||
private int retrieved;
|
||||
|
||||
@ -230,21 +252,11 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
}
|
||||
|
||||
Detail(StreamInput in) throws IOException {
|
||||
this.relevantRetrieved = in.readVInt();
|
||||
this.retrieved = in.readVInt();
|
||||
this(in.readVInt(), in.readVInt());
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
|
||||
throws IOException {
|
||||
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
|
||||
builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved);
|
||||
return builder;
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
|
||||
return new Detail((Integer) args[0], (Integer) args[1]);
|
||||
});
|
||||
private static final ConstructingObjectParser<Detail, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, true, args -> new Detail((Integer) args[0], (Integer) args[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD);
|
||||
@ -257,8 +269,16 @@ public class PrecisionAtK implements EvaluationMetric {
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(relevantRetrieved);
|
||||
out.writeVInt(retrieved);
|
||||
out.writeVLong(relevantRetrieved);
|
||||
out.writeVLong(retrieved);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
|
||||
throws IOException {
|
||||
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
|
||||
builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved);
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -31,8 +31,10 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.add(
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME),
|
||||
PrecisionAtK::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallAtK.NAME),
|
||||
RecallAtK::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
|
||||
MeanReciprocalRank::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
|
||||
@ -42,6 +44,8 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME),
|
||||
PrecisionAtK.Detail::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(RecallAtK.NAME),
|
||||
RecallAtK.Detail::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
|
||||
MeanReciprocalRank.Detail::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
|
||||
|
@ -58,12 +58,14 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, RecallAtK.NAME, RecallAtK.Detail::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
|
||||
|
@ -0,0 +1,280 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.OptionalInt;
|
||||
|
||||
import javax.naming.directory.SearchResult;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
||||
|
||||
/**
|
||||
* Metric implementing Recall@K
|
||||
* (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall).<br>
|
||||
* By default documents with a rating equal to or greater than 1 are considered to
|
||||
* be "relevant" for this calculation. This value can be changed using the
|
||||
* `relevant_rating_threshold` parameter.<br>
|
||||
* The `k` parameter (defaults to 10) controls the search window size.
|
||||
*/
|
||||
public class RecallAtK implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "recall";
|
||||
|
||||
private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1;
|
||||
private static final int DEFAULT_K = 10;
|
||||
|
||||
private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ParseField K_FIELD = new ParseField("k");
|
||||
|
||||
private final int relevantRatingThreshold;
|
||||
private final int k;
|
||||
|
||||
/**
 * Metric implementing Recall@K.
 *
 * @param relevantRatingThreshold
 *            ratings equal to or above this value will be considered relevant.
 * @param k
 *            controls the window size for the search results the metric takes into account
 * @throws IllegalArgumentException
 *             if {@code relevantRatingThreshold} is negative or {@code k} is not positive
 */
public RecallAtK(int relevantRatingThreshold, int k) {
    // Only negative thresholds are rejected; a threshold of 0 is accepted.
    if (relevantRatingThreshold < 0) {
        // The message names this metric (recall), not precision, so a failure points at the right setting.
        throw new IllegalArgumentException("Relevant rating threshold for recall must be positive integer.");
    }
    if (k <= 0) {
        throw new IllegalArgumentException("Window size k must be positive.");
    }
    this.relevantRatingThreshold = relevantRatingThreshold;
    this.k = k;
}
|
||||
|
||||
public RecallAtK() {
|
||||
this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_K);
|
||||
}
|
||||
|
||||
RecallAtK(StreamInput in) throws IOException {
|
||||
this(in.readVInt(), in.readVInt());
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<RecallAtK, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
|
||||
Integer relevantRatingThreshold = (Integer) args[0];
|
||||
Integer k = (Integer) args[1];
|
||||
return new RecallAtK(
|
||||
relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold,
|
||||
k == null ? DEFAULT_K : k);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||
}
|
||||
|
||||
public static RecallAtK fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(getRelevantRatingThreshold());
|
||||
out.writeVInt(getK());
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.startObject(NAME);
|
||||
builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), getRelevantRatingThreshold());
|
||||
builder.field(K_FIELD.getPreferredName(), getK());
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the rating threshold at or above which ratings are considered to be
|
||||
* "relevant" for this metric. Defaults to 1.
|
||||
*/
|
||||
public int getRelevantRatingThreshold() {
|
||||
return relevantRatingThreshold;
|
||||
}
|
||||
|
||||
public int getK() {
|
||||
return k;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OptionalInt forcedSearchSize() {
|
||||
return OptionalInt.of(getK());
|
||||
}
|
||||
|
||||
/**
|
||||
* Binarizes a rating based on the relevant rating threshold.
|
||||
*/
|
||||
private boolean isRelevant(int rating) {
|
||||
return rating >= getRelevantRatingThreshold();
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute recall at k based on provided relevant document IDs.
|
||||
*
|
||||
* @return recall at k for above {@link SearchResult} list.
|
||||
**/
|
||||
@Override
|
||||
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
|
||||
List<RatedDocument> ratedDocs) {
|
||||
|
||||
List<RatedSearchHit> ratedSearchHits = joinHitsWithRatings(hits, ratedDocs);
|
||||
|
||||
int relevantRetrieved = 0;
|
||||
for (RatedSearchHit hit : ratedSearchHits) {
|
||||
OptionalInt rating = hit.getRating();
|
||||
if (rating.isPresent() && isRelevant(rating.getAsInt())) {
|
||||
relevantRetrieved++;
|
||||
}
|
||||
}
|
||||
|
||||
int relevant = 0;
|
||||
for (RatedDocument rd : ratedDocs) {
|
||||
if(isRelevant(rd.getRating())) {
|
||||
relevant++;
|
||||
}
|
||||
}
|
||||
|
||||
double recall = 0.0;
|
||||
if (relevant > 0) {
|
||||
recall = (double) relevantRetrieved / relevant;
|
||||
}
|
||||
|
||||
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, recall);
|
||||
evalQueryQuality.setMetricDetails(new RecallAtK.Detail(relevantRetrieved, relevant));
|
||||
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
|
||||
return evalQueryQuality;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
RecallAtK other = (RecallAtK) obj;
|
||||
return Objects.equals(relevantRatingThreshold, other.relevantRatingThreshold)
|
||||
&& Objects.equals(k, other.k);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(relevantRatingThreshold, k);
|
||||
}
|
||||
|
||||
public static final class Detail implements MetricDetail {
|
||||
|
||||
private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
|
||||
private static final ParseField RELEVANT_DOCS_FIELD = new ParseField("relevant_docs");
|
||||
private long relevantRetrieved;
|
||||
private long relevant;
|
||||
|
||||
Detail(long relevantRetrieved, long relevant) {
|
||||
this.relevantRetrieved = relevantRetrieved;
|
||||
this.relevant = relevant;
|
||||
}
|
||||
|
||||
Detail(StreamInput in) throws IOException {
|
||||
this.relevantRetrieved = in.readVLong();
|
||||
this.relevant = in.readVLong();
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Detail, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, true, args -> new Detail((Integer) args[0], (Integer) args[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD);
|
||||
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_FIELD);
|
||||
}
|
||||
|
||||
public static Detail fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVLong(relevantRetrieved);
|
||||
out.writeVLong(relevant);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
|
||||
throws IOException {
|
||||
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
|
||||
builder.field(RELEVANT_DOCS_FIELD.getPreferredName(), relevant);
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public long getRelevantRetrieved() {
|
||||
return relevantRetrieved;
|
||||
}
|
||||
|
||||
public long getRelevant() {
|
||||
return relevant;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
RecallAtK.Detail other = (RecallAtK.Detail) obj;
|
||||
return Objects.equals(relevantRetrieved, other.relevantRetrieved)
|
||||
&& Objects.equals(relevant, other.relevant);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(relevantRetrieved, relevant);
|
||||
}
|
||||
}
|
||||
}
|
@ -48,25 +48,25 @@ import static org.hamcrest.CoreMatchers.containsString;
|
||||
|
||||
public class PrecisionAtKTests extends ESTestCase {
|
||||
|
||||
private static final int IRRELEVANT_RATING_0 = 0;
|
||||
private static final int RELEVANT_RATING_1 = 1;
|
||||
private static final int IRRELEVANT_RATING = 0;
|
||||
private static final int RELEVANT_RATING = 1;
|
||||
|
||||
public void testPrecisionAtFiveCalculation() {
|
||||
public void testCalculation() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals(1, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
|
||||
}
|
||||
|
||||
public void testPrecisionAtFiveIgnoreOneResult() {
|
||||
public void testIgnoreOneResult() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "2", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "3", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING_0));
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "2", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "3", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING));
|
||||
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals((double) 4 / 5, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(4, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -75,10 +75,10 @@ public class PrecisionAtKTests extends ESTestCase {
|
||||
|
||||
/**
|
||||
* test that the relevant rating threshold can be set to something larger than
|
||||
* 1. e.g. we set it to 2 here and expect dics 0-2 to be not relevant, doc 3 and
|
||||
* 4 to be relevant
|
||||
* 1. e.g. we set it to 2 here and expect docs 0-1 to be not relevant, docs 2-4
|
||||
* to be relevant
|
||||
*/
|
||||
public void testPrecisionAtFiveRelevanceThreshold() {
|
||||
public void testRelevanceThreshold() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", 0));
|
||||
rated.add(createRatedDoc("test", "1", 1));
|
||||
@ -94,13 +94,15 @@ public class PrecisionAtKTests extends ESTestCase {
|
||||
|
||||
public void testPrecisionAtFiveCorrectIndex() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING_0));
|
||||
rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING));
|
||||
// the following search hits contain only the last three documents
|
||||
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated);
|
||||
List<RatedDocument> ratedSubList = rated.subList(2, 5);
|
||||
PrecisionAtK precisionAtK = new PrecisionAtK(1, false, 5);
|
||||
EvalQueryQuality evaluated = (precisionAtK).evaluate("id", toSearchHits(ratedSubList, "test"), rated);
|
||||
assertEquals((double) 2 / 3, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
|
||||
@ -108,8 +110,8 @@ public class PrecisionAtKTests extends ESTestCase {
|
||||
|
||||
public void testIgnoreUnlabeled() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1));
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
|
||||
// add an unlabeled search hit
|
||||
SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test"), 3);
|
||||
searchHits[2] = new SearchHit(2, "2", new Text(MapperService.SINGLE_MAPPING_NAME), Collections.emptyMap());
|
||||
@ -121,7 +123,7 @@ public class PrecisionAtKTests extends ESTestCase {
|
||||
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
|
||||
PrecisionAtK prec = new PrecisionAtK(true);
|
||||
evaluated = prec.evaluate("id", searchHits, rated);
|
||||
assertEquals((double) 2 / 2, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -140,7 +142,7 @@ public class PrecisionAtKTests extends ESTestCase {
|
||||
assertEquals(5, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
|
||||
PrecisionAtK prec = new PrecisionAtK(true);
|
||||
evaluated = prec.evaluate("id", hits, Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
|
@ -73,6 +73,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||
static RankEvalSpec createTestItem() {
|
||||
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
|
||||
() -> PrecisionAtKTests.createTestItem(),
|
||||
() -> RecallAtKTests.createTestItem(),
|
||||
() -> MeanReciprocalRankTests.createTestItem(),
|
||||
() -> DiscountedCumulativeGainTests.createTestItem()));
|
||||
|
||||
@ -149,6 +150,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
|
@ -0,0 +1,249 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.action.OriginalIndices;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchShardTarget;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
import static org.elasticsearch.test.XContentTestUtils.insertRandomFields;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
|
||||
public class RecallAtKTests extends ESTestCase {
|
||||
|
||||
private static final int IRRELEVANT_RATING = 0;
|
||||
private static final int RELEVANT_RATING = 1;
|
||||
|
||||
public void testCalculation() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
|
||||
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals(1, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
public void testIgnoreOneResult() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "2", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "3", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING));
|
||||
|
||||
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals((double) 4 / 4, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
/**
|
||||
* Test that the relevant rating threshold can be set to something larger than
|
||||
* 1. e.g. we set it to 2 here and expect docs 0-1 to be not relevant, docs 2-4
|
||||
* to be relevant, and only 0-3 are hits.
|
||||
*/
|
||||
public void testRelevanceThreshold() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", 0)); // not relevant, hit
|
||||
rated.add(createRatedDoc("test", "1", 1)); // not relevant, hit
|
||||
rated.add(createRatedDoc("test", "2", 2)); // relevant, hit
|
||||
rated.add(createRatedDoc("test", "3", 3)); // relevant
|
||||
rated.add(createRatedDoc("test", "4", 4)); // relevant
|
||||
|
||||
RecallAtK recallAtN = new RecallAtK(2, 5);
|
||||
|
||||
EvalQueryQuality evaluated = recallAtN.evaluate("id", toSearchHits(rated.subList(0,3), "test"), rated);
|
||||
assertEquals((double) 1 / 3, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(3, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
public void testCorrectIndex() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "1", RELEVANT_RATING));
|
||||
rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING));
|
||||
|
||||
// the following search hits contain only the last three documents
|
||||
List<RatedDocument> ratedSubList = rated.subList(2, 5);
|
||||
|
||||
EvalQueryQuality evaluated = (new RecallAtK(1, 5)).evaluate("id", toSearchHits(ratedSubList, "test"), rated);
|
||||
assertEquals((double) 2 / 4, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(2, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
public void testNoRatedDocs() throws Exception {
|
||||
int k = 5;
|
||||
SearchHit[] hits = new SearchHit[k];
|
||||
for (int i = 0; i < k; i++) {
|
||||
hits[i] = new SearchHit(i, i + "", new Text(""), Collections.emptyMap());
|
||||
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE));
|
||||
}
|
||||
|
||||
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", hits, Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
public void testNoResults() throws Exception {
|
||||
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", new SearchHit[0], Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
public void testNoResultsWithRatedDocs() throws Exception {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(createRatedDoc("test", "0", RELEVANT_RATING));
|
||||
|
||||
EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", new SearchHit[0], rated);
|
||||
assertEquals(0.0d, evaluated.metricScore(), 0.00001);
|
||||
assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant());
|
||||
}
|
||||
|
||||
public void testParseFromXContent() throws IOException {
|
||||
String xContent = " {\n" + " \"relevant_rating_threshold\" : 2" + "}";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
RecallAtK recallAtK = RecallAtK.fromXContent(parser);
|
||||
assertEquals(2, recallAtK.getRelevantRatingThreshold());
|
||||
}
|
||||
}
|
||||
|
||||
public void testCombine() {
|
||||
RecallAtK metric = new RecallAtK();
|
||||
List<EvalQueryQuality> partialResults = new ArrayList<>(3);
|
||||
partialResults.add(new EvalQueryQuality("a", 0.1));
|
||||
partialResults.add(new EvalQueryQuality("b", 0.2));
|
||||
partialResults.add(new EvalQueryQuality("c", 0.6));
|
||||
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new RecallAtK(-1, 10));
|
||||
}
|
||||
|
||||
public void testInvalidK() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new RecallAtK(1, -10));
|
||||
}
|
||||
|
||||
public static RecallAtK createTestItem() {
|
||||
return new RecallAtK(randomIntBetween(0, 10), randomIntBetween(1, 50));
|
||||
}
|
||||
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
RecallAtK testItem = createTestItem();
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
|
||||
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
|
||||
try (XContentParser itemParser = createParser(shuffled)) {
|
||||
itemParser.nextToken();
|
||||
itemParser.nextToken();
|
||||
RecallAtK parsedItem = RecallAtK.fromXContent(itemParser);
|
||||
assertNotSame(testItem, parsedItem);
|
||||
assertEquals(testItem, parsedItem);
|
||||
assertEquals(testItem.hashCode(), parsedItem.hashCode());
|
||||
}
|
||||
}
|
||||
|
||||
public void testXContentParsingIsNotLenient() throws IOException {
|
||||
RecallAtK testItem = createTestItem();
|
||||
XContentType xContentType = randomFrom(XContentType.values());
|
||||
BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean());
|
||||
BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random());
|
||||
try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) {
|
||||
parser.nextToken();
|
||||
parser.nextToken();
|
||||
XContentParseException exception = expectThrows(XContentParseException.class, () -> RecallAtK.fromXContent(parser));
|
||||
assertThat(exception.getMessage(), containsString("[recall] unknown field"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
RecallAtK original = createTestItem();
|
||||
RecallAtK deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
|
||||
RecallAtK::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
checkEqualsAndHashCode(createTestItem(), RecallAtKTests::copy, RecallAtKTests::mutate);
|
||||
}
|
||||
|
||||
private static RecallAtK copy(RecallAtK original) {
|
||||
return new RecallAtK(original.getRelevantRatingThreshold(), original.forcedSearchSize().getAsInt());
|
||||
}
|
||||
|
||||
private static RecallAtK mutate(RecallAtK original) {
|
||||
RecallAtK recallAtK;
|
||||
switch (randomIntBetween(0, 1)) {
|
||||
case 0:
|
||||
recallAtK = new RecallAtK(
|
||||
randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
|
||||
original.forcedSearchSize().getAsInt());
|
||||
break;
|
||||
case 1:
|
||||
recallAtK = new RecallAtK(
|
||||
original.getRelevantRatingThreshold(),
|
||||
original.forcedSearchSize().getAsInt() + 1);
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("The test should only allow two parameters mutated");
|
||||
}
|
||||
return recallAtK;
|
||||
}
|
||||
|
||||
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index) {
|
||||
SearchHit[] hits = new SearchHit[rated.size()];
|
||||
for (int i = 0; i < rated.size(); i++) {
|
||||
hits[i] = new SearchHit(i, i + "", new Text(""), Collections.emptyMap());
|
||||
hits[i].shard(new SearchShardTarget("testnode", new ShardId(index, "uuid", 0), null, OriginalIndices.NONE));
|
||||
}
|
||||
return hits;
|
||||
}
|
||||
|
||||
private static RatedDocument createRatedDoc(String index, String id, int rating) {
|
||||
return new RatedDocument(index, id, rating);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user