From c8d9d063ca21f9c755b0fbe90d594d46fcf0eca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Thu, 3 Nov 2016 17:33:39 +0100 Subject: [PATCH] Removing the 'size' parameter from the dcg metric --- ...nAt.java => DiscountedCumulativeGain.java} | 68 ++++--------------- .../index/rankeval/RankEvalPlugin.java | 4 +- .../rankeval/RankedListQualityMetric.java | 4 +- .../index/rankeval/ReciprocalRank.java | 2 +- ...ava => DiscountedCumulativeGainTests.java} | 44 ++++++------ .../index/rankeval/RankEvalSpecTests.java | 2 +- .../index/rankeval/ReciprocalRankTests.java | 3 +- .../rest-api-spec/test/rank_eval/20_dcg.yaml | 6 +- 8 files changed, 47 insertions(+), 86 deletions(-) rename modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/{DiscountedCumulativeGainAt.java => DiscountedCumulativeGain.java} (70%) rename modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/{DiscountedCumulativeGainAtTests.java => DiscountedCumulativeGainTests.java} (86%) diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java similarity index 70% rename from modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java rename to modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java index 4f5718a702e..a82e8ad65ff 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java @@ -38,56 +38,36 @@ import java.util.stream.Collectors; import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings; -public class DiscountedCumulativeGainAt implements RankedListQualityMetric { +public class DiscountedCumulativeGain implements RankedListQualityMetric { - /** rank position up to which to check results. */ - private int position; /** If set to true, the dcg will be normalized (ndcg) */ private boolean normalize; /** If set to, this will be the rating for docs the user hasn't supplied an explicit rating for */ private Integer unknownDocRating; - public static final String NAME = "dcg_at_n"; + public static final String NAME = "dcg"; private static final double LOG2 = Math.log(2.0); - /** - * Initialises position with 10 - * */ - public DiscountedCumulativeGainAt() { - this.position = 10; + public DiscountedCumulativeGain() { } /** - * @param position number of top results to check against a given set of relevant results. Must be positive. - */ - public DiscountedCumulativeGainAt(int position) { - if (position <= 0) { - throw new IllegalArgumentException("number of results to check needs to be positive but was " + position); - } - this.position = position; - } - - /** - * @param position number of top results to check against a given set of relevant results. Must be positive. * @param normalize If set to true, dcg will be normalized (ndcg) * See https://en.wikipedia.org/wiki/Discounted_cumulative_gain * @param unknownDocRating the rating for docs the user hasn't supplied an explicit rating for * */ - public DiscountedCumulativeGainAt(int position, boolean normalize, Integer unknownDocRating) { - this(position); + public DiscountedCumulativeGain( boolean normalize, Integer unknownDocRating) { this.normalize = normalize; this.unknownDocRating = unknownDocRating; } - public DiscountedCumulativeGainAt(StreamInput in) throws IOException { - this(in.readInt()); + public DiscountedCumulativeGain(StreamInput in) throws IOException { normalize = in.readBoolean(); unknownDocRating = in.readOptionalVInt(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeInt(position); out.writeBoolean(normalize); out.writeOptionalVInt(unknownDocRating); } @@ -97,20 +77,6 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric { return NAME; } - /** - * Return number of search results to check for quality metric. - */ - public int getPosition() { - return this.position; - } - - /** - * set number of search results to check for quality metric. - */ - public void setPosition(int position) { - this.position = position; - } - /** * If set to true, the dcg will be normalized (ndcg) */ @@ -143,8 +109,8 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric { public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List ratedDocs) { List allRatings = ratedDocs.stream().mapToInt(RatedDocument::getRating).boxed().collect(Collectors.toList()); List ratedHits = joinHitsWithRatings(hits, ratedDocs); - List ratingsInSearchHits = new ArrayList<>(Math.min(ratedHits.size(), position)); - for (RatedSearchHit hit : ratedHits.subList(0, position)) { + List ratingsInSearchHits = new ArrayList<>(ratedHits.size()); + for (RatedSearchHit hit : ratedHits) { // unknownDocRating might be null, which means it will be unrated docs are ignored in the dcg calculation // we still need to add them as a placeholder so the rank of the subsequent ratings is correct ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating)); @@ -173,19 +139,17 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric { return dcg; } - private static final ParseField SIZE_FIELD = new ParseField("size"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); - private static final ObjectParser PARSER = - new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGainAt()); + private static final ObjectParser PARSER = + new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGain()); static { - PARSER.declareInt(DiscountedCumulativeGainAt::setPosition, SIZE_FIELD); - PARSER.declareBoolean(DiscountedCumulativeGainAt::setNormalize, NORMALIZE_FIELD); - PARSER.declareInt(DiscountedCumulativeGainAt::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD); + PARSER.declareBoolean(DiscountedCumulativeGain::setNormalize, NORMALIZE_FIELD); + PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD); } - public static DiscountedCumulativeGainAt fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { + public static DiscountedCumulativeGain fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { return PARSER.apply(parser, matcher); } @@ -193,7 +157,6 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.startObject(NAME); - builder.field(SIZE_FIELD.getPreferredName(), this.position); builder.field(NORMALIZE_FIELD.getPreferredName(), this.normalize); if (unknownDocRating != null) { builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating); @@ -211,15 +174,14 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric { if (obj == null || getClass() != obj.getClass()) { return false; } - DiscountedCumulativeGainAt other = (DiscountedCumulativeGainAt) obj; - return Objects.equals(position, other.position) && - Objects.equals(normalize, other.normalize) && + DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj; + return Objects.equals(normalize, other.normalize) && Objects.equals(unknownDocRating, other.unknownDocRating); } @Override public final int hashCode() { - return Objects.hash(position, normalize, unknownDocRating); + return Objects.hash(normalize, unknownDocRating); } // TODO maybe also add debugging breakdown here diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index c2261646867..d4d6a3a22ea 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -52,8 +52,8 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin { List namedWriteables = new ArrayList<>(); namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, Precision.NAME, Precision::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGainAt.NAME, - DiscountedCumulativeGainAt::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME, + DiscountedCumulativeGain::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, Precision.NAME, Precision.Breakdown::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, ReciprocalRank.NAME, diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java index ae25f0e2b70..cd1f6cb0e85 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java @@ -70,8 +70,8 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable { case ReciprocalRank.NAME: rc = ReciprocalRank.fromXContent(parser, context); break; - case DiscountedCumulativeGainAt.NAME: - rc = DiscountedCumulativeGainAt.fromXContent(parser, context); + case DiscountedCumulativeGain.NAME: + rc = DiscountedCumulativeGain.fromXContent(parser, context); break; default: throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName); diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java index caf52c0e17c..160cc0aee77 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java @@ -45,7 +45,7 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW public class ReciprocalRank implements RankedListQualityMetric { public static final String NAME = "reciprocal_rank"; - public static final int DEFAULT_MAX_ACCEPTABLE_RANK = 10; + public static final int DEFAULT_MAX_ACCEPTABLE_RANK = Integer.MAX_VALUE; private int maxAcceptableRank = DEFAULT_MAX_ACCEPTABLE_RANK; /** ratings equal or above this value will be considered relevant. */ diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java similarity index 86% rename from modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java index 9c5a5e3bb14..32a4d0dd82a 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java @@ -30,14 +30,12 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.concurrent.ExecutionException; import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments; -public class DiscountedCumulativeGainAtTests extends ESTestCase { +public class DiscountedCumulativeGainTests extends ESTestCase { /** * Assuming the docs are ranked in the following order: @@ -53,7 +51,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { * * dcg = 13.84826362927298 (sum of last column) */ - public void testDCGAt() throws IOException, InterruptedException, ExecutionException { + public void testDCGAt() { List rated = new ArrayList<>(); int[] relevanceRatings = new int[] { 3, 2, 3, 0, 1, 2 }; InternalSearchHit[] hits = new InternalSearchHit[6]; @@ -62,7 +60,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0))); } - DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6); + DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); /** @@ -97,7 +95,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { * * dcg = 12.779642067948913 (sum of last column) */ - public void testDCGAtSixMissingRatings() throws IOException, InterruptedException, ExecutionException { + public void testDCGAtSixMissingRatings() { List rated = new ArrayList<>(); Integer[] relevanceRatings = new Integer[] { 3, 2, 3, null, 1}; InternalSearchHit[] hits = new InternalSearchHit[6]; @@ -110,7 +108,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0))); } - DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6); + DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); EvalQueryQuality result = dcg.evaluate("id", hits, rated); assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001); assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size()); @@ -149,21 +147,24 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { * * dcg = 12.392789260714371 (sum of last column until position 4) */ - public void testDCGAtFourMoreRatings() throws IOException, InterruptedException, ExecutionException { - List rated = new ArrayList<>(); + public void testDCGAtFourMoreRatings() { Integer[] relevanceRatings = new Integer[] { 3, 2, 3, null, 1, null}; - InternalSearchHit[] hits = new InternalSearchHit[6]; + List ratedDocs = new ArrayList<>(); for (int i = 0; i < 6; i++) { if (i < relevanceRatings.length) { if (relevanceRatings[i] != null) { - rated.add(new RatedDocument("index", "type", Integer.toString(i), relevanceRatings[i])); + ratedDocs.add(new RatedDocument("index", "type", Integer.toString(i), relevanceRatings[i])); } } + } + // only create four hits + InternalSearchHit[] hits = new InternalSearchHit[4]; + for (int i = 0; i < 4; i++) { hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0))); } - DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(4); - EvalQueryQuality result = dcg.evaluate("id", Arrays.copyOfRange(hits, 0, 4), rated); + DiscountedCumulativeGain dcg = new DiscountedCumulativeGain(); + EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs); assertEquals(12.392789260714371 , result.getQualityLevel(), 0.00001); assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size()); @@ -183,33 +184,32 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { * idcg = 13.347184833073591 (sum of last column) */ dcg.setNormalize(true); - assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001); + assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001); } public void testParseFromXContent() throws IOException { String xContent = " {\n" - + " \"size\": 8,\n" + + " \"unknown_doc_rating\": 2,\n" + " \"normalize\": true\n" + "}"; XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent); - DiscountedCumulativeGainAt dcgAt = DiscountedCumulativeGainAt.fromXContent(parser, () -> ParseFieldMatcher.STRICT); - assertEquals(8, dcgAt.getPosition()); + DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser, () -> ParseFieldMatcher.STRICT); + assertEquals(2, dcgAt.getUnknownDocRating().intValue()); assertEquals(true, dcgAt.getNormalize()); } - public static DiscountedCumulativeGainAt createTestItem() { - int position = randomIntBetween(0, 1000); + public static DiscountedCumulativeGain createTestItem() { boolean normalize = randomBoolean(); Integer unknownDocRating = new Integer(randomIntBetween(0, 1000)); - return new DiscountedCumulativeGainAt(position, normalize, unknownDocRating); + return new DiscountedCumulativeGain(normalize, unknownDocRating); } public void testXContentRoundtrip() throws IOException { - DiscountedCumulativeGainAt testItem = createTestItem(); + DiscountedCumulativeGain testItem = createTestItem(); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem); itemParser.nextToken(); itemParser.nextToken(); - DiscountedCumulativeGainAt parsedItem = DiscountedCumulativeGainAt.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); + DiscountedCumulativeGain parsedItem = DiscountedCumulativeGain.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); assertNotSame(testItem, parsedItem); assertEquals(testItem, parsedItem); assertEquals(testItem.hashCode(), parsedItem.hashCode()); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java index c5685d18a15..b7a6746d965 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java @@ -93,7 +93,7 @@ public class RankEvalSpecTests extends ESTestCase { if (randomBoolean()) { metric = PrecisionTests.createTestItem(); } else { - metric = DiscountedCumulativeGainAtTests.createTestItem(); + metric = DiscountedCumulativeGainTests.createTestItem(); } RankEvalSpec testItem = new RankEvalSpec(specs, metric); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java index 763e6d22365..9220d501c84 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java @@ -34,7 +34,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Vector; -import java.util.concurrent.ExecutionException; public class ReciprocalRankTests extends ESTestCase { @@ -98,7 +97,7 @@ public class ReciprocalRankTests extends ESTestCase { * e.g. we set it to 2 here and expect dics 0-2 to be not relevant, so first relevant doc has * third ranking position, so RR should be 1/3 */ - public void testPrecisionAtFiveRelevanceThreshold() throws IOException, InterruptedException, ExecutionException { + public void testPrecisionAtFiveRelevanceThreshold() { List rated = new ArrayList<>(); rated.add(new RatedDocument("test", "testtype", "0", 0)); rated.add(new RatedDocument("test", "testtype", "1", 1)); diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml index 8310389c02a..2f784790ad5 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml @@ -60,7 +60,7 @@ {"_index" : "foo", "_type" : "bar", "_id" : "doc6", "rating": 2}] } ], - "metric" : { "dcg_at_n": { "size": 6}} + "metric" : { "dcg": {}} } - match: {rank_eval.quality_level: 13.84826362927298} @@ -85,7 +85,7 @@ {"_index" : "foo", "_type" : "bar", "_id" : "doc6", "rating": 2}] }, ], - "metric" : { "dcg_at_n": { "size": 6}} + "metric" : { "dcg": { }} } - match: {rank_eval.quality_level: 10.29967439154499} @@ -121,7 +121,7 @@ {"_index" : "foo", "_type" : "bar", "_id" : "doc6", "rating": 2}] }, ], - "metric" : { "dcg_at_n": { "size": 6}} + "metric" : { "dcg": { }} } - match: {rank_eval.quality_level: 12.073969010408984}