Register ERR metric with NamedXContentRegistry (#32320)

This adds the Expected Reciprocal Rank (ERR) metric to the xContent parsers provided by the
rank evaluation module and to the high level REST client registry. It also adds integration
tests to make sure the metric is correctly registered and usable from the client.
This commit is contained in:
Christoph Büscher 2018-07-24 16:05:43 +02:00 committed by GitHub
parent 46709f1406
commit 59cf600e03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 92 additions and 12 deletions

View File

@ -22,7 +22,11 @@ package org.elasticsearch.client;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.rankeval.DiscountedCumulativeGain;
import org.elasticsearch.index.rankeval.EvalQueryQuality;
import org.elasticsearch.index.rankeval.EvaluationMetric;
import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
import org.elasticsearch.index.rankeval.PrecisionAtK;
import org.elasticsearch.index.rankeval.RankEvalRequest;
import org.elasticsearch.index.rankeval.RankEvalResponse;
@ -35,8 +39,10 @@ import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -64,15 +70,7 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
* calculation where all unlabeled documents are treated as not relevant.
*/
public void testRankEvalRequest() throws IOException {
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4");
amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0"));
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery);
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery);
List<RatedRequest> specifications = new ArrayList<>();
specifications.add(amsterdamRequest);
specifications.add(berlinRequest);
List<RatedRequest> specifications = createTestEvaluationSpec();
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec spec = new RankEvalSpec(specifications, metric);
@ -114,6 +112,38 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
}
/**
 * Builds the rated requests shared by the tests in this class: an "amsterdam_query" whose rated
 * documents come from both "index" and "index2", and a "berlin_query" with a single rated
 * document in "index". Both requests use the same match-all search source.
 *
 * @return a new, mutable list holding the amsterdam request followed by the berlin request
 */
private static List<RatedRequest> createTestEvaluationSpec() {
    SearchSourceBuilder matchAllSource = new SearchSourceBuilder();
    matchAllSource.query(new MatchAllQueryBuilder());

    // the amsterdam ratings span two indices
    List<RatedDocument> amsterdamRatings = createRelevant("index", "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4");
    amsterdamRatings.addAll(createRelevant("index2", "amsterdam0"));

    List<RatedRequest> ratedRequests = new ArrayList<>();
    ratedRequests.add(new RatedRequest("amsterdam_query", amsterdamRatings, matchAllSource));
    ratedRequests.add(new RatedRequest("berlin_query", createRelevant("index", "berlin"), matchAllSource));
    return ratedRequests;
}
/**
 * Test case checks that the default metrics are registered and usable.
 *
 * Each metric is sent through the high level client (sync and async via {@code execute})
 * against the shared evaluation spec, and the returned overall score is compared with the
 * expected value at the same position in {@code expectedScores}.
 */
public void testMetrics() throws IOException {
    List<RatedRequest> specifications = createTestEvaluationSpec();
    List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new,
            () -> new ExpectedReciprocalRank(1));
    double[] expectedScores = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095};
    // the two collections are matched by position, so they must stay the same length
    assertEquals("every metric needs exactly one expected score", metrics.size(), expectedScores.length);

    // Double.MIN_VALUE is the smallest positive double and therefore demands exact equality;
    // use a real tolerance that is still far below the difference between any two expected scores
    final double tolerance = 1e-10;
    for (int i = 0; i < expectedScores.length; i++) {
        RankEvalSpec spec = new RankEvalSpec(specifications, metrics.get(i).get());
        RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" });
        RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
        assertEquals("unexpected score for metric at position " + i, expectedScores[i], response.getMetricScore(), tolerance);
    }
}
/**
 * Creates one {@link RatedDocument} per given document id, each in the given index and
 * constructed with {@code 1} as its third argument.
 *
 * @param indexName the index passed to every created document
 * @param docs the document ids, one rated document is created for each
 * @return a list with the created documents in the order of {@code docs}
 */
private static List<RatedDocument> createRelevant(String indexName, String... docs) {
    List<RatedDocument> relevantDocs = new ArrayList<>();
    for (String docId : docs) {
        relevantDocs.add(new RatedDocument(indexName, docId, 1));
    }
    return relevantDocs;
}

View File

@ -20,6 +20,7 @@
package org.elasticsearch.client;
import com.fasterxml.jackson.core.JsonParseException;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
@ -60,6 +61,7 @@ import org.elasticsearch.common.xcontent.cbor.CborXContent;
import org.elasticsearch.common.xcontent.smile.SmileXContent;
import org.elasticsearch.index.rankeval.DiscountedCumulativeGain;
import org.elasticsearch.index.rankeval.EvaluationMetric;
import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
import org.elasticsearch.index.rankeval.MetricDetail;
import org.elasticsearch.index.rankeval.PrecisionAtK;
@ -616,7 +618,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(8, namedXContents.size());
assertEquals(10, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -630,14 +632,16 @@ public class RestHighLevelClientTests extends ESTestCase {
assertEquals(Integer.valueOf(2), categories.get(Aggregation.class));
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
assertEquals(Integer.valueOf(3), categories.get(EvaluationMetric.class));
assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class));
assertTrue(names.contains(PrecisionAtK.NAME));
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
assertTrue(names.contains(MeanReciprocalRank.NAME));
assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class));
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class));
assertTrue(names.contains(PrecisionAtK.NAME));
assertTrue(names.contains(MeanReciprocalRank.NAME));
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
}
public void testApiNamingConventions() throws Exception {

View File

@ -65,6 +65,9 @@ public class ExpectedReciprocalRank implements EvaluationMetric {
public static final String NAME = "expected_reciprocal_rank";
/**
* Creates the metric from a maximum relevance only. Delegates to the three-argument
* constructor, passing {@code null} as the second argument and {@code DEFAULT_K} as the third.
* NOTE(review): the meaning of the {@code null} second argument is not visible in this hunk;
* confirm against the three-argument constructor.
*
* @param maxRelevance the highest expected relevance in the data
*/
public ExpectedReciprocalRank(int maxRelevance) {
this(maxRelevance, null, DEFAULT_K);
}

View File

@ -37,12 +37,17 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
MeanReciprocalRank::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(ExpectedReciprocalRank.NAME),
ExpectedReciprocalRank::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME),
PrecisionAtK.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(ExpectedReciprocalRank.NAME),
ExpectedReciprocalRank.Detail::fromXContent));
return namedXContent;
}
}

View File

@ -60,10 +60,14 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(MetricDetail.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank.Detail::new));
return namedWriteables;
}

View File

@ -161,3 +161,37 @@ setup:
- match: {details.berlin_query.metric_details.mean_reciprocal_rank: {"first_relevant": 2}}
- match: {details.berlin_query.unrated_docs: [ {"_index": "foo", "_id": "doc1"}]}
---
"Expected Reciprocal Rank":
# Skipped on versions up to 6.3.99 because ERR was introduced in 6.4
- skip:
version: " - 6.3.99"
reason: ERR was introduced in 6.4
# Evaluate two rated queries using the expected_reciprocal_rank metric (maximum_relevance 1, k 5)
- do:
rank_eval:
body: {
"requests" : [
{
"id": "amsterdam_query",
"request": { "query": { "match" : {"text" : "amsterdam" }}},
"ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}]
},
{
"id" : "berlin_query",
"request": { "query": { "match" : { "text" : "berlin" } }, "size" : 10 },
"ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}]
}
],
"metric" : {
"expected_reciprocal_rank": {
"maximum_relevance" : 1,
"k" : 5
}
}
}
# The overall score is bracketed by a lower and an upper bound instead of an exact match
- gt: {metric_score: 0.2083333}
- lt: {metric_score: 0.2083334}
# Per-query details report the number of unrated documents under the metric's own name
- match: {details.amsterdam_query.metric_details.expected_reciprocal_rank.unrated_docs: 2}
- match: {details.berlin_query.metric_details.expected_reciprocal_rank.unrated_docs: 1}