diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java index e620fa63d53..18e79bec3b9 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -35,8 +36,9 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; -public class DiscountedCumulativeGainAt extends RankedListQualityMetric { +public class DiscountedCumulativeGainAt extends RankedListQualityMetric { /** rank position up to which to check results. */ private int position; @@ -83,6 +85,17 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric { this.position = position; } + /** + * @param position number of top results to check against a given set of relevant results. Must be positive. // TODO is there a way to enforce this? + * @param normalize If set to true, dcg will be normalized (ndcg) See https://en.wikipedia.org/wiki/Discounted_cumulative_gain + * @param unknownDocRating the rating for docs the user hasn't supplied an explicit rating for + * */ + public DiscountedCumulativeGainAt(int position, boolean normalize, Integer unknownDocRating) { + this.position = position; + this.normalize = normalize; + this.unknownDocRating = unknownDocRating; + } + /** * Return number of search results to check for quality metric. */ @@ -178,13 +191,24 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric { PARSER.declareInt(DiscountedCumulativeGainAt::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD); } + @Override + public DiscountedCumulativeGainAt fromXContent(XContentParser parser, ParseFieldMatcher matcher) { + return DiscountedCumulativeGainAt.fromXContent(parser, new ParseFieldMatcherSupplier() { + @Override + public ParseFieldMatcher getParseFieldMatcher() { + return matcher; + } + }); + } + public static DiscountedCumulativeGainAt fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { return PARSER.apply(parser, matcher); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); + //builder.startObject(NAME); // TODO roundtrip xcontent only works w/o this, wtf? + builder.startObject(); builder.field(SIZE_FIELD.getPreferredName(), this.position); builder.field(NORMALIZE_FIELD.getPreferredName(), this.normalize); if (unknownDocRating != null) { @@ -193,4 +217,23 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric { builder.endObject(); return builder; } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DiscountedCumulativeGainAt other = (DiscountedCumulativeGainAt) obj; + return Objects.equals(position, other.position) && + Objects.equals(normalize, other.normalize) && + Objects.equals(unknownDocRating, other.unknownDocRating); + } + + @Override + public final int hashCode() { + return Objects.hash(getClass(), position, normalize, unknownDocRating); + } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java index ba571c576ef..c9ad3deb759 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -40,7 +41,7 @@ import javax.naming.directory.SearchResult; * * Documents of unkonwn quality are ignored in the precision at n computation and returned by document id. * */ -public class PrecisionAtN extends RankedListQualityMetric { +public class PrecisionAtN extends RankedListQualityMetric { /** Number of results to check against a given set of relevant results. */ private int n; @@ -90,6 +91,17 @@ public class PrecisionAtN extends RankedListQualityMetric { PARSER.declareInt(ConstructingObjectParser.constructorArg(), SIZE_FIELD); } + @Override + public PrecisionAtN fromXContent(XContentParser parser, ParseFieldMatcher matcher) { + return PrecisionAtN.fromXContent(parser, new ParseFieldMatcherSupplier() { + + @Override + public ParseFieldMatcher getParseFieldMatcher() { + return matcher; + } + }); + } + public static PrecisionAtN fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { return PARSER.apply(parser, matcher); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java index e423ab3533c..30793d564e3 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.support.ToXContentToBytes; import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.FromXContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser.Token; @@ -38,7 +39,9 @@ import java.util.List; * * RelevancyLevel specifies the type of object determining the relevancy level of some known docid. * */ -public abstract class RankedListQualityMetric extends ToXContentToBytes implements NamedWriteable { +public abstract class RankedListQualityMetric + extends ToXContentToBytes + implements NamedWriteable, FromXContentBuilder { /** * Returns a single metric representing the ranking quality of a set of returned documents diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java index 8e7c1aac888..c5faecb986e 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocument.java @@ -21,11 +21,14 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.action.support.ToXContentToBytes; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.FromXContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -35,12 +38,12 @@ import java.util.Objects; /** * A document ID and its rating for the query QA use case. * */ -public class RatedDocument extends ToXContentToBytes implements Writeable { +public class RatedDocument extends ToXContentToBytes implements Writeable, FromXContentBuilder { public static final ParseField RATING_FIELD = new ParseField("rating"); public static final ParseField KEY_FIELD = new ParseField("key"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("rated_document", + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("rated_document", a -> new RatedDocument((RatedDocumentKey) a[0], (Integer) a[1])); static { @@ -93,8 +96,19 @@ public class RatedDocument extends ToXContentToBytes implements Writeable { out.writeVInt(rating); } - public static RatedDocument fromXContent(XContentParser parser, RankEvalContext context) throws IOException { - return PARSER.apply(parser, context); + @Override + public RatedDocument fromXContent(XContentParser parser, ParseFieldMatcher parseFieldMatcher) throws IOException { + return RatedDocument.fromXContent(parser, new ParseFieldMatcherSupplier() { + + @Override + public ParseFieldMatcher getParseFieldMatcher() { + return parseFieldMatcher; + } + }); + } + + public static RatedDocument fromXContent(XContentParser parser, ParseFieldMatcherSupplier supplier) throws IOException { + return PARSER.apply(parser, supplier); } @Override diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocumentKey.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocumentKey.java index d68149a2ca0..c10bc3f0ce0 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocumentKey.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedDocumentKey.java @@ -21,6 +21,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.action.support.ToXContentToBytes; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -36,7 +37,7 @@ public class RatedDocumentKey extends ToXContentToBytes implements Writeable { public static final ParseField TYPE_FIELD = new ParseField("type"); public static final ParseField INDEX_FIELD = new ParseField("index"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("ratings", + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("ratings", a -> new RatedDocumentKey((String) a[0], (String) a[1], (String) a[2])); static { @@ -103,7 +104,7 @@ public class RatedDocumentKey extends ToXContentToBytes implements Writeable { out.writeString(docId); } - public static RatedDocumentKey fromXContent(XContentParser parser, RankEvalContext context) throws IOException { + public static RatedDocumentKey fromXContent(XContentParser parser, ParseFieldMatcherSupplier context) throws IOException { return PARSER.apply(parser, context); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java index dd4e710859b..791e857eca8 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ReciprocalRank.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; import org.elasticsearch.common.ParseFieldMatcherSupplier; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -42,7 +43,7 @@ import javax.naming.directory.SearchResult; /** * Evaluate reciprocal rank. * */ -public class ReciprocalRank extends RankedListQualityMetric { +public class ReciprocalRank extends RankedListQualityMetric { public static final String NAME = "reciprocal_rank"; public static final int DEFAULT_MAX_ACCEPTABLE_RANK = 10; @@ -140,6 +141,17 @@ public class ReciprocalRank extends RankedListQualityMetric { PARSER.declareInt(ReciprocalRank::setMaxAcceptableRank, MAX_RANK_FIELD); } + @Override + public ReciprocalRank fromXContent(XContentParser parser, ParseFieldMatcher matcher) { + return ReciprocalRank.fromXContent(parser, new ParseFieldMatcherSupplier() { + + @Override + public ParseFieldMatcher getParseFieldMatcher() { + return matcher; + } + }); + } + public static ReciprocalRank fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { return PARSER.apply(parser, matcher); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java index 2221ee2a1e3..68207470382 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java @@ -26,7 +26,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.InternalSearchHit; -import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; @@ -34,7 +33,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutionException; -public class DiscountedCumulativeGainAtTests extends ESTestCase { +public class DiscountedCumulativeGainAtTests extends XContentRoundtripTestCase { /** * Assuming the docs are ranked in the following order: @@ -121,4 +120,13 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase { assertEquals(8, dcgAt.getPosition()); assertEquals(true, dcgAt.getNormalize()); } + + public void testXContentRoundtrip() throws IOException { + int position = randomIntBetween(0, 1000); + boolean normalize = randomBoolean(); + Integer unknownDocRating = new Integer(randomIntBetween(0, 1000)); + + DiscountedCumulativeGainAt testItem = new DiscountedCumulativeGainAt(position, normalize, unknownDocRating); + roundtrip(testItem); + } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java index 6e48b72d38e..8cb4a194026 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedDocumentTests.java @@ -19,41 +19,17 @@ package org.elasticsearch.index.rankeval; -import org.elasticsearch.common.ParseFieldMatcher; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.test.ESTestCase; - import java.io.IOException; -public class RatedDocumentTests extends ESTestCase { +public class RatedDocumentTests extends XContentRoundtripTestCase { public void testXContentParsing() throws IOException { String index = randomAsciiOfLength(10); String type = randomAsciiOfLength(10); String docId = randomAsciiOfLength(10); int rating = randomInt(); + RatedDocument testItem = new RatedDocument(new RatedDocumentKey(index, type, docId), rating); - - XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); - if (randomBoolean()) { - builder.prettyPrint(); - } - testItem.toXContent(builder, ToXContent.EMPTY_PARAMS); - XContentBuilder shuffled = shuffleXContent(builder); - XContentParser itemParser = XContentHelper.createParser(shuffled.bytes()); - itemParser.nextToken(); - - RankEvalContext context = new RankEvalContext(ParseFieldMatcher.STRICT, null, null); - RatedDocument parsedItem = RatedDocument.fromXContent(itemParser, context); - assertNotSame(testItem, parsedItem); - assertEquals(testItem, parsedItem); - assertEquals(testItem.hashCode(), parsedItem.hashCode()); - + roundtrip(testItem); } - } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/XContentRoundtripTestCase.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/XContentRoundtripTestCase.java new file mode 100644 index 00000000000..689a42726ff --- /dev/null +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/XContentRoundtripTestCase.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.action.support.ToXContentToBytes; +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.xcontent.FromXContentBuilder; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +public class XContentRoundtripTestCase> extends ESTestCase { + + public void roundtrip(T testItem) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); + if (randomBoolean()) { + builder.prettyPrint(); + } + testItem.toXContent(builder, ToXContent.EMPTY_PARAMS); + XContentBuilder shuffled = shuffleXContent(builder); + XContentParser itemParser = XContentHelper.createParser(shuffled.bytes()); + itemParser.nextToken(); + T parsedItem = testItem.fromXContent(itemParser, ParseFieldMatcher.STRICT); + assertNotSame(testItem, parsedItem); + assertEquals(testItem, parsedItem); + assertEquals(testItem.hashCode(), parsedItem.hashCode()); + } +}