From 59cf600e03f8148434655f2e2e23d88b52c20712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Tue, 24 Jul 2018 16:05:43 +0200 Subject: [PATCH] Register ERR metric with NamedXContentRegistry (#32320) This adds the ERR metric to the provided xContent parsers in the module and the high level rest client registry. Also adding integration tests to make sure the metric is correctly registered and usable from the client. --- .../org/elasticsearch/client/RankEvalIT.java | 48 +++++++++++++++---- .../client/RestHighLevelClientTests.java | 10 ++-- .../rankeval/ExpectedReciprocalRank.java | 3 ++ .../RankEvalNamedXContentProvider.java | 5 ++ .../index/rankeval/RankEvalPlugin.java | 4 ++ .../rest-api-spec/test/rank_eval/10_basic.yml | 34 +++++++++++++ 6 files changed, 92 insertions(+), 12 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java index 330afafd9ef..0af270cb051 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java @@ -22,7 +22,11 @@ package org.elasticsearch.client; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvalQueryQuality; +import org.elasticsearch.index.rankeval.EvaluationMetric; +import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; +import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.PrecisionAtK; import org.elasticsearch.index.rankeval.RankEvalRequest; import org.elasticsearch.index.rankeval.RankEvalResponse; @@ -35,8 +39,10 @@ import org.junit.Before; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -64,15 +70,7 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase { * calculation where all unlabeled documents are treated as not relevant. */ public void testRankEvalRequest() throws IOException { - SearchSourceBuilder testQuery = new SearchSourceBuilder(); - testQuery.query(new MatchAllQueryBuilder()); - List amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4"); - amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0")); - RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery); - RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery); - List specifications = new ArrayList<>(); - specifications.add(amsterdamRequest); - specifications.add(berlinRequest); + List specifications = createTestEvaluationSpec(); PrecisionAtK metric = new PrecisionAtK(1, false, 10); RankEvalSpec spec = new RankEvalSpec(specifications, metric); @@ -114,6 +112,38 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase { response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); } + private static List createTestEvaluationSpec() { + SearchSourceBuilder testQuery = new SearchSourceBuilder(); + testQuery.query(new MatchAllQueryBuilder()); + List amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4"); + amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0")); + RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery); + RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery); + List specifications = new ArrayList<>(); + specifications.add(amsterdamRequest); + specifications.add(berlinRequest); + return specifications; + } + + /** + * Test case checks that the default metrics are registered and usable + */ + public void testMetrics() throws IOException { + List specifications = createTestEvaluationSpec(); + List> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new, + () -> new ExpectedReciprocalRank(1)); + double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095}; + int i = 0; + for (Supplier metricSupplier : metrics) { + RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get()); + + RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" }); + RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); + assertEquals(expectedScores[i], response.getMetricScore(), Double.MIN_VALUE); + i++; + } + } + private static List createRelevant(String indexName, String... docs) { return Stream.of(docs).map(s -> new RatedDocument(indexName, s, 1)).collect(Collectors.toList()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 64a344790ca..48934a9bed8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; + import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.HttpResponse; @@ -60,6 +61,7 @@ import org.elasticsearch.common.xcontent.cbor.CborXContent; import org.elasticsearch.common.xcontent.smile.SmileXContent; import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvaluationMetric; +import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.MetricDetail; import org.elasticsearch.index.rankeval.PrecisionAtK; @@ -616,7 +618,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(8, namedXContents.size()); + assertEquals(10, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -630,14 +632,16 @@ public class RestHighLevelClientTests extends ESTestCase { assertEquals(Integer.valueOf(2), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); - assertEquals(Integer.valueOf(3), categories.get(EvaluationMetric.class)); + assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class)); assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); - assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class)); + assertTrue(names.contains(ExpectedReciprocalRank.NAME)); + assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class)); assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); + assertTrue(names.contains(ExpectedReciprocalRank.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java index 4aac29f299d..39e1266504d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java @@ -65,6 +65,9 @@ public class ExpectedReciprocalRank implements EvaluationMetric { public static final String NAME = "expected_reciprocal_rank"; + /** + * @param maxRelevance the highest expected relevance in the data + */ public ExpectedReciprocalRank(int maxRelevance) { this(maxRelevance, null, DEFAULT_K); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java index f2176113cdf..7eddcf9dff6 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java @@ -37,12 +37,17 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider { MeanReciprocalRank::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME), DiscountedCumulativeGain::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(ExpectedReciprocalRank.NAME), + ExpectedReciprocalRank::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME), MeanReciprocalRank.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME), DiscountedCumulativeGain.Detail::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(ExpectedReciprocalRank.NAME), + ExpectedReciprocalRank.Detail::fromXContent)); return namedXContent; } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 8ac2b7fbee5..0e5d754778f 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -60,10 +60,14 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin { namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MetricDetail.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank.Detail::new)); return namedWriteables; } diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml index fe877b37a68..ebe23ae53f4 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml @@ -161,3 +161,37 @@ setup: - match: {details.berlin_query.metric_details.mean_reciprocal_rank: {"first_relevant": 2}} - match: {details.berlin_query.unrated_docs: [ {"_index": "foo", "_id": "doc1"}]} +--- +"Expected Reciprocal Rank": + + - skip: + version: " - 6.3.99" + reason: ERR was introduced in 6.4 + + - do: + rank_eval: + body: { + "requests" : [ + { + "id": "amsterdam_query", + "request": { "query": { "match" : {"text" : "amsterdam" }}}, + "ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}] + }, + { + "id" : "berlin_query", + "request": { "query": { "match" : { "text" : "berlin" } }, "size" : 10 }, + "ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}] + } + ], + "metric" : { + "expected_reciprocal_rank": { + "maximum_relevance" : 1, + "k" : 5 + } + } + } + + - gt: {metric_score: 0.2083333} + - lt: {metric_score: 0.2083334} + - match: {details.amsterdam_query.metric_details.expected_reciprocal_rank.unrated_docs: 2} + - match: {details.berlin_query.metric_details.expected_reciprocal_rank.unrated_docs: 1}