Register ERR metric with NamedXContentRegistry (#32320)
This adds the ERR metric to the provided xContent parsers in the module and the high level rest client registry. Also adding integration tests to make sure the metric is correctly registered and usable from the client.
This commit is contained in:
parent
46709f1406
commit
59cf600e03
|
@ -22,7 +22,11 @@ package org.elasticsearch.client;
|
|||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.rankeval.DiscountedCumulativeGain;
|
||||
import org.elasticsearch.index.rankeval.EvalQueryQuality;
|
||||
import org.elasticsearch.index.rankeval.EvaluationMetric;
|
||||
import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtK;
|
||||
import org.elasticsearch.index.rankeval.RankEvalRequest;
|
||||
import org.elasticsearch.index.rankeval.RankEvalResponse;
|
||||
|
@ -35,8 +39,10 @@ import org.junit.Before;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -64,15 +70,7 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
|
|||
* calculation where all unlabeled documents are treated as not relevant.
|
||||
*/
|
||||
public void testRankEvalRequest() throws IOException {
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4");
|
||||
amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0"));
|
||||
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery);
|
||||
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery);
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
specifications.add(amsterdamRequest);
|
||||
specifications.add(berlinRequest);
|
||||
List<RatedRequest> specifications = createTestEvaluationSpec();
|
||||
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
|
||||
RankEvalSpec spec = new RankEvalSpec(specifications, metric);
|
||||
|
||||
|
@ -114,6 +112,38 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
|
|||
response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
|
||||
}
|
||||
|
||||
private static List<RatedRequest> createTestEvaluationSpec() {
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4");
|
||||
amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0"));
|
||||
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery);
|
||||
RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery);
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
specifications.add(amsterdamRequest);
|
||||
specifications.add(berlinRequest);
|
||||
return specifications;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test case checks that the default metrics are registered and usable
|
||||
*/
|
||||
public void testMetrics() throws IOException {
|
||||
List<RatedRequest> specifications = createTestEvaluationSpec();
|
||||
List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new,
|
||||
() -> new ExpectedReciprocalRank(1));
|
||||
double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095};
|
||||
int i = 0;
|
||||
for (Supplier<EvaluationMetric> metricSupplier : metrics) {
|
||||
RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get());
|
||||
|
||||
RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" });
|
||||
RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
|
||||
assertEquals(expectedScores[i], response.getMetricScore(), Double.MIN_VALUE);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
private static List<RatedDocument> createRelevant(String indexName, String... docs) {
|
||||
return Stream.of(docs).map(s -> new RatedDocument(indexName, s, 1)).collect(Collectors.toList());
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package org.elasticsearch.client;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonParseException;
|
||||
|
||||
import org.apache.http.HttpEntity;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.HttpResponse;
|
||||
|
@ -60,6 +61,7 @@ import org.elasticsearch.common.xcontent.cbor.CborXContent;
|
|||
import org.elasticsearch.common.xcontent.smile.SmileXContent;
|
||||
import org.elasticsearch.index.rankeval.DiscountedCumulativeGain;
|
||||
import org.elasticsearch.index.rankeval.EvaluationMetric;
|
||||
import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.MeanReciprocalRank;
|
||||
import org.elasticsearch.index.rankeval.MetricDetail;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtK;
|
||||
|
@ -616,7 +618,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(8, namedXContents.size());
|
||||
assertEquals(10, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -630,14 +632,16 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
assertEquals(Integer.valueOf(2), categories.get(Aggregation.class));
|
||||
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
|
||||
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class));
|
||||
assertTrue(names.contains(PrecisionAtK.NAME));
|
||||
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
|
||||
assertTrue(names.contains(MeanReciprocalRank.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class));
|
||||
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
|
||||
assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class));
|
||||
assertTrue(names.contains(PrecisionAtK.NAME));
|
||||
assertTrue(names.contains(MeanReciprocalRank.NAME));
|
||||
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
|
||||
assertTrue(names.contains(ExpectedReciprocalRank.NAME));
|
||||
}
|
||||
|
||||
public void testApiNamingConventions() throws Exception {
|
||||
|
|
|
@ -65,6 +65,9 @@ public class ExpectedReciprocalRank implements EvaluationMetric {
|
|||
|
||||
public static final String NAME = "expected_reciprocal_rank";
|
||||
|
||||
/**
|
||||
* @param maxRelevance the highest expected relevance in the data
|
||||
*/
|
||||
public ExpectedReciprocalRank(int maxRelevance) {
|
||||
this(maxRelevance, null, DEFAULT_K);
|
||||
}
|
||||
|
|
|
@ -37,12 +37,17 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
|
|||
MeanReciprocalRank::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
|
||||
DiscountedCumulativeGain::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(ExpectedReciprocalRank.NAME),
|
||||
ExpectedReciprocalRank::fromXContent));
|
||||
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME),
|
||||
PrecisionAtK.Detail::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
|
||||
MeanReciprocalRank.Detail::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
|
||||
DiscountedCumulativeGain.Detail::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(ExpectedReciprocalRank.NAME),
|
||||
ExpectedReciprocalRank.Detail::fromXContent));
|
||||
return namedXContent;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,10 +60,14 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
|
|||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(MetricDetail.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank.Detail::new));
|
||||
return namedWriteables;
|
||||
}
|
||||
|
||||
|
|
|
@ -161,3 +161,37 @@ setup:
|
|||
- match: {details.berlin_query.metric_details.mean_reciprocal_rank: {"first_relevant": 2}}
|
||||
- match: {details.berlin_query.unrated_docs: [ {"_index": "foo", "_id": "doc1"}]}
|
||||
|
||||
---
|
||||
"Expected Reciprocal Rank":
|
||||
|
||||
- skip:
|
||||
version: " - 6.3.99"
|
||||
reason: ERR was introduced in 6.4
|
||||
|
||||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"requests" : [
|
||||
{
|
||||
"id": "amsterdam_query",
|
||||
"request": { "query": { "match" : {"text" : "amsterdam" }}},
|
||||
"ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}]
|
||||
},
|
||||
{
|
||||
"id" : "berlin_query",
|
||||
"request": { "query": { "match" : { "text" : "berlin" } }, "size" : 10 },
|
||||
"ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}]
|
||||
}
|
||||
],
|
||||
"metric" : {
|
||||
"expected_reciprocal_rank": {
|
||||
"maximum_relevance" : 1,
|
||||
"k" : 5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- gt: {metric_score: 0.2083333}
|
||||
- lt: {metric_score: 0.2083334}
|
||||
- match: {details.amsterdam_query.metric_details.expected_reciprocal_rank.unrated_docs: 2}
|
||||
- match: {details.berlin_query.metric_details.expected_reciprocal_rank.unrated_docs: 1}
|
||||
|
|
Loading…
Reference in New Issue