diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/Precision.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/Precision.java index 4871a1e5359..f85e9460120 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/Precision.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/Precision.java @@ -45,29 +45,40 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW */ public class Precision implements RankedListQualityMetric { - /** ratings equal or above this value will be considered relevant. */ - private int relevantRatingThreshhold = 1; - public static final String NAME = "precision"; private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); + private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled"); private static final ObjectParser PARSER = new ObjectParser<>(NAME, Precision::new); + /** + * This setting controls how unlabeled documents in the search hits are + * treated. Set to 'true', unlabeled documents are ignored and neither count + * as true or false positives. Set to 'false', they are treated as false positives. + */ + private boolean ignoreUnlabeled = false; + + /** ratings equal or above this value will be considered relevant. */ + private int relevantRatingThreshhold = 1; + public Precision() { // needed for supplier in parser } static { PARSER.declareInt(Precision::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD); + PARSER.declareBoolean(Precision::setIgnoreUnlabeled, IGNORE_UNLABELED_FIELD); } public Precision(StreamInput in) throws IOException { relevantRatingThreshhold = in.readOptionalVInt(); + ignoreUnlabeled = in.readOptionalBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalVInt(relevantRatingThreshhold); + out.writeOptionalBoolean(ignoreUnlabeled); } @Override @@ -90,6 +101,20 @@ public class Precision implements RankedListQualityMetric { return relevantRatingThreshhold ; } + /** + * Sets the 'ìgnore_unlabeled' parameter + * */ + public void setIgnoreUnlabeled(boolean ignoreUnlabeled) { + this.ignoreUnlabeled = ignoreUnlabeled; + } + + /** + * Gets the 'ìgnore_unlabeled' parameter + * */ + public boolean getIgnoreUnlabeled() { + return ignoreUnlabeled; + } + public static Precision fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { return PARSER.apply(parser, matcher); } @@ -110,6 +135,8 @@ public class Precision implements RankedListQualityMetric { } else { falsePositives++; } + } else if (ignoreUnlabeled == false) { + falsePositives++; } } double precision = 0.0; @@ -122,35 +149,12 @@ public class Precision implements RankedListQualityMetric { return evalQueryQuality; } - // TODO add abstraction that also works for other metrics - public enum Rating { - IRRELEVANT, RELEVANT; - } - - /** - * Needed to get the enum accross serialisation boundaries. - * */ - public static class RatingMapping { - public static Integer mapFrom(Rating rating) { - if (Rating.RELEVANT.equals(rating)) { - return 1; - } - return 0; - } - - public static Rating mapTo(Integer rating) { - if (rating == 1) { - return Rating.RELEVANT; - } - return Rating.IRRELEVANT; - } - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.startObject(NAME); builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold); + builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled); builder.endObject(); builder.endObject(); return builder; @@ -165,12 +169,13 @@ public class Precision implements RankedListQualityMetric { return false; } Precision other = (Precision) obj; - return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold); + return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold) && + Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled); } @Override public final int hashCode() { - return Objects.hash(relevantRatingThreshhold); + return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled); } public static class Breakdown implements MetricDetails { diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java index 5854d4cab1a..75c7818f40e 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.common.text.Text; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.Index; -import org.elasticsearch.index.rankeval.Precision.Rating; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.InternalSearchHit; @@ -32,6 +31,7 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Vector; @@ -106,6 +106,29 @@ public class PrecisionTests extends ESTestCase { assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved()); } + public void testIgnoreUnlabeled() { + List rated = new ArrayList<>(); + rated.add(new RatedDocument("test", "testtype", "0", Rating.RELEVANT.ordinal())); + rated.add(new RatedDocument("test", "testtype", "1", Rating.RELEVANT.ordinal())); + // add an unlabeled search hit + SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test", "testtype"), 3); + searchHits[2] = new InternalSearchHit(2, "2", new Text("testtype"), Collections.emptyMap()); + ((InternalSearchHit)searchHits[2]).shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0)); + + EvalQueryQuality evaluated = (new Precision()).evaluate("id", searchHits, rated); + assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001); + assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved()); + + // also try with setting `ignore_unlabeled` + Precision prec = new Precision(); + prec.setIgnoreUnlabeled(true); + evaluated = prec.evaluate("id", searchHits, rated); + assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001); + assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved()); + } + public void testNoRatedDocs() throws Exception { InternalSearchHit[] hits = new InternalSearchHit[5]; for (int i = 0; i < 5; i++) { @@ -115,6 +138,14 @@ public class PrecisionTests extends ESTestCase { EvalQueryQuality evaluated = (new Precision()).evaluate("id", hits, Collections.emptyList()); assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001); assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(5, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved()); + + // also try with setting `ignore_unlabeled` + Precision prec = new Precision(); + prec.setIgnoreUnlabeled(true); + evaluated = prec.evaluate("id", hits, Collections.emptyList()); + assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001); + assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved()); assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved()); } @@ -141,6 +172,7 @@ public class PrecisionTests extends ESTestCase { if (randomBoolean()) { precision.setRelevantRatingThreshhold(randomIntBetween(0, 10)); } + precision.setIgnoreUnlabeled(randomBoolean()); return precision; } @@ -163,4 +195,8 @@ public class PrecisionTests extends ESTestCase { } return hits; } + + public enum Rating { + IRRELEVANT, RELEVANT; + } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java index b250e9f9cfb..6ce18182fce 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java @@ -22,13 +22,12 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; -import org.elasticsearch.index.rankeval.Precision.Rating; +import org.elasticsearch.index.rankeval.PrecisionTests.Rating; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.junit.Before; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -70,7 +69,7 @@ public class RankEvalRequestTests extends ESIntegTestCase { refresh(); } - public void testPrecisionAtRequest() throws IOException { + public void testPrecisionAtRequest() { List indices = Arrays.asList(new String[] { "test" }); List types = Arrays.asList(new String[] { "testtype" }); @@ -84,7 +83,9 @@ public class RankEvalRequestTests extends ESIntegTestCase { berlinRequest.setSummaryFields(Arrays.asList(new String[]{ "text", "title" })); specifications.add(berlinRequest); - RankEvalSpec task = new RankEvalSpec(specifications, new Precision()); + Precision metric = new Precision(); + metric.setIgnoreUnlabeled(true); + RankEvalSpec task = new RankEvalSpec(specifications, metric); RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); builder.setRankEvalSpec(task); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java index 3f00562d1e1..4ce34ebc0d1 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java @@ -23,7 +23,7 @@ import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.text.Text; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.Index; -import org.elasticsearch.index.rankeval.Precision.Rating; +import org.elasticsearch.index.rankeval.PrecisionTests.Rating; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.InternalSearchHit; diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yaml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yaml index ea705bdd33d..a0170f9722a 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yaml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yaml @@ -56,7 +56,7 @@ "ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}] } ], - "metric" : { "precision": { }} + "metric" : { "precision": { "ignore_unlabeled" : true }} } - match: { rank_eval.quality_level: 1}