From 0578a964835c3328457965b1fe7e96e1234ead1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Mon, 8 Aug 2016 13:29:44 +0200 Subject: [PATCH 1/3] Merge RankEvalResult with Response The current response object only serves as a wrapper around the result object. This change merges the two classes into one. --- .../index/rankeval/RankEvalResponse.java | 56 +++++++++---- .../index/rankeval/RankEvalResult.java | 80 ------------------- .../rankeval/TransportRankEvalAction.java | 4 +- .../index/rankeval/RankEvalRequestTests.java | 17 +--- 4 files changed, 43 insertions(+), 114 deletions(-) delete mode 100644 modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResult.java diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java index 0e520febdb8..da8651e6a24 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResponse.java @@ -38,42 +38,64 @@ import java.util.Map; * Documents of unknown quality - i.e. those that haven't been supplied in the set of annotated documents but have been returned * by the search are not taken into consideration when computing precision at n - they are ignored. * - * TODO get rid of either this or RankEvalResult **/ +//TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results public class RankEvalResponse extends ActionResponse implements ToXContent { - - private RankEvalResult qualityResult; + /**ID of QA specification this result was generated for.*/ + private String specId; + /**Average precision observed when issuing query intents with this specification.*/ + private double qualityLevel; + /**Mapping from intent id to all documents seen for this intent that were not annotated.*/ + private Map> unknownDocs; public RankEvalResponse() { - } + @SuppressWarnings("unchecked") public RankEvalResponse(StreamInput in) throws IOException { super.readFrom(in); - this.qualityResult = new RankEvalResult(in); + this.specId = in.readString(); + this.qualityLevel = in.readDouble(); + this.unknownDocs = (Map>) in.readGenericValue(); + } + + public RankEvalResponse(String specId, double qualityLevel, Map> unknownDocs) { + this.specId = specId; + this.qualityLevel = qualityLevel; + this.unknownDocs = unknownDocs; + } + + public String getSpecId() { + return specId; + } + + public double getQualityLevel() { + return qualityLevel; + } + + public Map> getUnknownDocs() { + return unknownDocs; + } + + @Override + public String toString() { + return "RankEvalResult, ID :[" + specId + "], quality: " + qualityLevel + ", unknown docs: " + unknownDocs; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - qualityResult.writeTo(out); - } - - public void setRankEvalResult(RankEvalResult result) { - this.qualityResult = result; - } - - public RankEvalResult getRankEvalResult() { - return qualityResult; + out.writeString(specId); + out.writeDouble(qualityLevel); + out.writeGenericValue(getUnknownDocs()); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject("rank_eval"); - builder.field("spec_id", qualityResult.getSpecId()); - builder.field("quality_level", qualityResult.getQualityLevel()); + builder.field("spec_id", specId); + builder.field("quality_level", qualityLevel); builder.startArray("unknown_docs"); - Map> unknownDocs = qualityResult.getUnknownDocs(); for (String key : unknownDocs.keySet()) { builder.startObject(); builder.field(key, unknownDocs.get(key)); diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResult.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResult.java deleted file mode 100644 index 726b3c82aa7..00000000000 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalResult.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.index.rankeval; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; - -import java.io.IOException; -import java.util.Collection; -import java.util.Map; - -/** - * For each precision at n computation the id of the search request specification used to generate search requests is returned - * for reference. In addition the averaged precision and the ids of all documents returned but not found annotated is returned. - * */ -// TODO do we need an extra class for this or it RankEvalResponse enough? -// TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results -public class RankEvalResult implements Writeable { - /**ID of QA specification this result was generated for.*/ - private String specId; - /**Average precision observed when issuing query intents with this specification.*/ - private double qualityLevel; - /**Mapping from intent id to all documents seen for this intent that were not annotated.*/ - private Map> unknownDocs; - - @SuppressWarnings("unchecked") - public RankEvalResult(StreamInput in) throws IOException { - this.specId = in.readString(); - this.qualityLevel = in.readDouble(); - this.unknownDocs = (Map>) in.readGenericValue(); - } - - public RankEvalResult(String specId, double quality, Map> unknownDocs) { - this.specId = specId; - this.qualityLevel = quality; - this.unknownDocs = unknownDocs; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(specId); - out.writeDouble(qualityLevel); - out.writeGenericValue(getUnknownDocs()); - } - - public String getSpecId() { - return specId; - } - - public double getQualityLevel() { - return qualityLevel; - } - - public Map> getUnknownDocs() { - return unknownDocs; - } - - @Override - public String toString() { - return "RankEvalResult, ID :[" + specId + "], quality: " + qualityLevel + ", unknown docs: " + unknownDocs; - } -} \ No newline at end of file diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index 7b13014e310..f2e97068e62 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -105,10 +105,8 @@ public class TransportRankEvalAction extends HandledTransportAction>> entrySet = result.getUnknownDocs().entrySet(); + assertEquals(specId, response.getSpecId()); + assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE); + Set>> entrySet = response.getUnknownDocs().entrySet(); assertEquals(2, entrySet.size()); for (Entry> entry : entrySet) { if (entry.getKey() == "amsterdam_query") { From 87e13ca8bbb6fdb409aa601d6b74cc400951004e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Fri, 8 Jul 2016 15:11:52 +0200 Subject: [PATCH 2/3] Add Discounted Cumulative Gain metric --- .../rankeval/DiscountedCumulativeGainAtN.java | 118 ++++++++++++++++++ .../rankeval/RankedListQualityMetric.java | 3 + .../DiscountedCumulativeGainAtNTests.java | 70 +++++++++++ .../rest-api-spec/test/rank_eval/20_dcg.yaml | 105 ++++++++++++++++ 4 files changed, 296 insertions(+) create mode 100644 modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java create mode 100644 modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java create mode 100644 modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java new file mode 100644 index 00000000000..c1bed32952b --- /dev/null +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java @@ -0,0 +1,118 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcherSupplier; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchHit; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DiscountedCumulativeGainAtN extends RankedListQualityMetric { + + /** Number of results to check against a given set of relevant results. */ + private int n; + + public static final String NAME = "dcg_at_n"; + private static final double LOG2 = Math.log(2.0); + + public DiscountedCumulativeGainAtN(StreamInput in) throws IOException { + n = in.readInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(n); + } + + @Override + public String getWriteableName() { + return NAME; + } + + /** + * Initialises n with 10 + * */ + public DiscountedCumulativeGainAtN() { + this.n = 10; + } + + /** + * @param n number of top results to check against a given set of relevant results. Must be positive. + */ + public DiscountedCumulativeGainAtN(int n) { + if (n <= 0) { + throw new IllegalArgumentException("number of results to check needs to be positive but was " + n); + } + this.n = n; + } + + /** + * Return number of search results to check for quality. + */ + public int getN() { + return n; + } + + @Override + public EvalQueryQuality evaluate(SearchHit[] hits, List ratedDocs) { + Map ratedDocsById = new HashMap<>(); + for (RatedDocument doc : ratedDocs) { + ratedDocsById.put(doc.getDocID(), doc); + } + + Collection unknownDocIds = new ArrayList(); + double dcg = 0; + + for (int i = 0; (i < n && i < hits.length); i++) { + int rank = i + 1; // rank is 1-based + String id = hits[i].getId(); + RatedDocument ratedDoc = ratedDocsById.get(id); + if (ratedDoc != null) { + int rel = ratedDoc.getRating(); + dcg += (Math.pow(2, rel) - 1) / ((Math.log(rank + 1) / LOG2)); + } else { + unknownDocIds.add(id); + } + } + return new EvalQueryQuality(dcg, unknownDocIds); + } + + private static final ParseField SIZE_FIELD = new ParseField("size"); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("dcg_at", a -> new DiscountedCumulativeGainAtN((Integer) a[0])); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), SIZE_FIELD); + } + + public static DiscountedCumulativeGainAtN fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { + return PARSER.apply(parser, matcher); + } +} diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java index 1bfd6e516f5..ba28b6f9bb0 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java @@ -63,6 +63,9 @@ public abstract class RankedListQualityMetric implements NamedWriteable { case ReciprocalRank.NAME: rc = ReciprocalRank.fromXContent(parser, context); break; + case DiscountedCumulativeGainAtN.NAME: + rc = DiscountedCumulativeGainAtN.fromXContent(parser, context); + break; default: throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java new file mode 100644 index 00000000000..2114feaa835 --- /dev/null +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.internal.InternalSearchHit; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; + +public class DiscountedCumulativeGainAtNTests extends ESTestCase { + + /** + * Assuming the docs are ranked in the following order: + * + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) + * ------------------------------------------------------------------------------------------- + * 1 | 3 | 7.0 | 1.0 | 7.0 + * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 3 | 3 | 7.0 | 2.0 | 3.5 + * 4 | 0 | 0.0 | 2.321928094887362 | 0.0 + * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 + * 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666 + */ + public void testDCGAtSix() throws IOException, InterruptedException, ExecutionException { + List rated = new ArrayList<>(); + int[] relevanceRatings = new int[] { 3, 2, 3, 0, 1, 2 }; + SearchHit[] hits = new InternalSearchHit[6]; + for (int i = 0; i < 6; i++) { + rated.add(new RatedDocument(Integer.toString(i), relevanceRatings[i])); + hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); + } + assertEquals(13.84826362927298d, (new DiscountedCumulativeGainAtN(6)).evaluate(hits, rated).getQualityLevel(), 0.00001); + } + + + public void testParseFromXContent() throws IOException { + String xContent = " {\n" + + " \"size\": 8\n" + + "}"; + XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent); + DiscountedCumulativeGainAtN dcgAt = DiscountedCumulativeGainAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT); + assertEquals(8, dcgAt.getN()); + } +} diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml new file mode 100644 index 00000000000..6fb28286138 --- /dev/null +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/20_dcg.yaml @@ -0,0 +1,105 @@ +--- +"Response format": + + - do: + index: + index: foo + type: bar + id: doc1 + body: { "bar": 1 } + + - do: + index: + index: foo + type: bar + id: doc2 + body: { "bar": 2 } + + - do: + index: + index: foo + type: bar + id: doc3 + body: { "bar": 3 } + + - do: + index: + index: foo + type: bar + id: doc4 + body: { "bar": 4 } + - do: + index: + index: foo + type: bar + id: doc5 + body: { "bar": 5 } + - do: + index: + index: foo + type: bar + id: doc6 + body: { "bar": 6 } + + - do: + indices.refresh: {} + + - do: + rank_eval: + body: { + "spec_id" : "dcg_qa_queries", + "requests" : [ + { + "id": "dcg_query", + "request": { "query": { "match_all" : {}}, "sort" : [ "bar" ] }, + "ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}] + } + ], + "metric" : { "dcg_at_n": { "size": 6}} + } + + - match: {rank_eval.spec_id: "dcg_qa_queries"} + - match: {rank_eval.quality_level: 13.84826362927298} + +# reverse the order in which the results are returned (less relevant docs first) + + - do: + rank_eval: + body: { + "spec_id" : "dcg_qa_queries", + "requests" : [ + { + "id": "dcg_query_reverse", + "request": { "query": { "match_all" : {}}, "sort" : [ {"bar" : "desc" }] }, + "ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}] + }, + ], + "metric" : { "dcg_at_n": { "size": 6}} + } + + - match: {rank_eval.spec_id: "dcg_qa_queries"} + - match: {rank_eval.quality_level: 10.29967439154499} + +# if we mix both, we should get the average + + - do: + rank_eval: + body: { + "spec_id" : "dcg_qa_queries", + "requests" : [ + { + "id": "dcg_query", + "request": { "query": { "match_all" : {}}, "sort" : [ "bar" ] }, + "ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}] + }, + { + "id": "dcg_query_reverse", + "request": { "query": { "match_all" : {}}, "sort" : [ {"bar" : "desc" }] }, + "ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}] + }, + ], + "metric" : { "dcg_at_n": { "size": 6}} + } + + - match: {rank_eval.spec_id: "dcg_qa_queries"} + - match: {rank_eval.quality_level: 12.073969010408984} From fa459f88ddb08e2bf18fe8b6cdc720f6dfc79a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Wed, 27 Jul 2016 15:22:14 +0200 Subject: [PATCH 3/3] Add normalization option When switched on, compute the normalized ndcg variant. --- .../rankeval/DiscountedCumulativeGainAt.java | 183 ++++++++++++++++++ .../rankeval/DiscountedCumulativeGainAtN.java | 118 ----------- .../rankeval/RankedListQualityMetric.java | 4 +- .../DiscountedCumulativeGainAtNTests.java | 70 ------- .../DiscountedCumulativeGainAtTests.java | 121 ++++++++++++ .../index/rankeval/ReciprocalRankTests.java | 14 +- 6 files changed, 316 insertions(+), 194 deletions(-) create mode 100644 modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java delete mode 100644 modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java delete mode 100644 modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java create mode 100644 modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java new file mode 100644 index 00000000000..2db310b185d --- /dev/null +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAt.java @@ -0,0 +1,183 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcherSupplier; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchHit; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DiscountedCumulativeGainAt extends RankedListQualityMetric { + + /** rank position up to which to check results. */ + private int position; + /** If set to true, the dcg will be normalized (ndcg) */ + private boolean normalize; + /** If set to, this will be the rating for docs the user hasn't supplied an explicit rating for */ + private Integer unknownDocRating; + + public static final String NAME = "dcg_at_n"; + private static final double LOG2 = Math.log(2.0); + + public DiscountedCumulativeGainAt(StreamInput in) throws IOException { + position = in.readInt(); + normalize = in.readBoolean(); + unknownDocRating = in.readOptionalVInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(position); + out.writeBoolean(normalize); + out.writeOptionalVInt(unknownDocRating); + } + + @Override + public String getWriteableName() { + return NAME; + } + + /** + * Initialises position with 10 + * */ + public DiscountedCumulativeGainAt() { + this.position = 10; + } + + /** + * @param position number of top results to check against a given set of relevant results. Must be positive. + */ + public DiscountedCumulativeGainAt(int position) { + if (position <= 0) { + throw new IllegalArgumentException("number of results to check needs to be positive but was " + position); + } + this.position = position; + } + + /** + * Return number of search results to check for quality metric. + */ + public int getPosition() { + return this.position; + } + + /** + * set number of search results to check for quality metric. + */ + public void setPosition(int position) { + this.position = position; + } + + /** + * If set to true, the dcg will be normalized (ndcg) + */ + public void setNormalize(boolean normalize) { + this.normalize = normalize; + } + + /** + * check whether this metric computes only dcg or "normalized" ndcg + */ + public boolean getNormalize() { + return this.normalize; + } + + /** + * the rating for docs the user hasn't supplied an explicit rating for + */ + public void setUnknownDocRating(int unknownDocRating) { + this.unknownDocRating = unknownDocRating; + } + + /** + * check whether this metric computes only dcg or "normalized" ndcg + */ + public Integer getUnknownDocRating() { + return this.unknownDocRating; + } + + @Override + public EvalQueryQuality evaluate(SearchHit[] hits, List ratedDocs) { + Map ratedDocsById = new HashMap<>(); + for (RatedDocument doc : ratedDocs) { + ratedDocsById.put(doc.getDocID(), doc); + } + + Collection unknownDocIds = new ArrayList<>(); + List ratings = new ArrayList<>(); + for (int i = 0; (i < position && i < hits.length); i++) { + String id = hits[i].getId(); + RatedDocument ratedDoc = ratedDocsById.get(id); + if (ratedDoc != null) { + ratings.add(ratedDoc.getRating()); + } else { + unknownDocIds.add(id); + if (unknownDocRating != null) { + ratings.add(unknownDocRating); + } + } + } + double dcg = computeDCG(ratings); + + if (normalize) { + Collections.sort(ratings, Collections.reverseOrder()); + double idcg = computeDCG(ratings); + dcg = dcg / idcg; + } + return new EvalQueryQuality(dcg, unknownDocIds); + } + + private static double computeDCG(List ratings) { + int rank = 1; + double dcg = 0; + for (int rating : ratings) { + dcg += (Math.pow(2, rating) - 1) / ((Math.log(rank + 1) / LOG2)); + rank++; + } + return dcg; + } + + private static final ParseField SIZE_FIELD = new ParseField("size"); + private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); + private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); + private static final ObjectParser PARSER = + new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGainAt()); + + static { + PARSER.declareInt(DiscountedCumulativeGainAt::setPosition, SIZE_FIELD); + PARSER.declareBoolean(DiscountedCumulativeGainAt::setNormalize, NORMALIZE_FIELD); + PARSER.declareInt(DiscountedCumulativeGainAt::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD); + } + + public static DiscountedCumulativeGainAt fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { + return PARSER.apply(parser, matcher); + } +} diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java deleted file mode 100644 index c1bed32952b..00000000000 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtN.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.index.rankeval; - -import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.ParseFieldMatcherSupplier; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.search.SearchHit; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class DiscountedCumulativeGainAtN extends RankedListQualityMetric { - - /** Number of results to check against a given set of relevant results. */ - private int n; - - public static final String NAME = "dcg_at_n"; - private static final double LOG2 = Math.log(2.0); - - public DiscountedCumulativeGainAtN(StreamInput in) throws IOException { - n = in.readInt(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeInt(n); - } - - @Override - public String getWriteableName() { - return NAME; - } - - /** - * Initialises n with 10 - * */ - public DiscountedCumulativeGainAtN() { - this.n = 10; - } - - /** - * @param n number of top results to check against a given set of relevant results. Must be positive. - */ - public DiscountedCumulativeGainAtN(int n) { - if (n <= 0) { - throw new IllegalArgumentException("number of results to check needs to be positive but was " + n); - } - this.n = n; - } - - /** - * Return number of search results to check for quality. - */ - public int getN() { - return n; - } - - @Override - public EvalQueryQuality evaluate(SearchHit[] hits, List ratedDocs) { - Map ratedDocsById = new HashMap<>(); - for (RatedDocument doc : ratedDocs) { - ratedDocsById.put(doc.getDocID(), doc); - } - - Collection unknownDocIds = new ArrayList(); - double dcg = 0; - - for (int i = 0; (i < n && i < hits.length); i++) { - int rank = i + 1; // rank is 1-based - String id = hits[i].getId(); - RatedDocument ratedDoc = ratedDocsById.get(id); - if (ratedDoc != null) { - int rel = ratedDoc.getRating(); - dcg += (Math.pow(2, rel) - 1) / ((Math.log(rank + 1) / LOG2)); - } else { - unknownDocIds.add(id); - } - } - return new EvalQueryQuality(dcg, unknownDocIds); - } - - private static final ParseField SIZE_FIELD = new ParseField("size"); - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("dcg_at", a -> new DiscountedCumulativeGainAtN((Integer) a[0])); - - static { - PARSER.declareInt(ConstructingObjectParser.constructorArg(), SIZE_FIELD); - } - - public static DiscountedCumulativeGainAtN fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) { - return PARSER.apply(parser, matcher); - } -} diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java index ba28b6f9bb0..c346f0d0d59 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java @@ -63,8 +63,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable { case ReciprocalRank.NAME: rc = ReciprocalRank.fromXContent(parser, context); break; - case DiscountedCumulativeGainAtN.NAME: - rc = DiscountedCumulativeGainAtN.fromXContent(parser, context); + case DiscountedCumulativeGainAt.NAME: + rc = DiscountedCumulativeGainAt.fromXContent(parser, context); break; default: throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java deleted file mode 100644 index 2114feaa835..00000000000 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtNTests.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.index.rankeval; - -import org.elasticsearch.common.ParseFieldMatcher; -import org.elasticsearch.common.text.Text; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.internal.InternalSearchHit; -import org.elasticsearch.test.ESTestCase; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.ExecutionException; - -public class DiscountedCumulativeGainAtNTests extends ESTestCase { - - /** - * Assuming the docs are ranked in the following order: - * - * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) - * ------------------------------------------------------------------------------------------- - * 1 | 3 | 7.0 | 1.0 | 7.0 - * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 - * 3 | 3 | 7.0 | 2.0 | 3.5 - * 4 | 0 | 0.0 | 2.321928094887362 | 0.0 - * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 - * 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666 - */ - public void testDCGAtSix() throws IOException, InterruptedException, ExecutionException { - List rated = new ArrayList<>(); - int[] relevanceRatings = new int[] { 3, 2, 3, 0, 1, 2 }; - SearchHit[] hits = new InternalSearchHit[6]; - for (int i = 0; i < 6; i++) { - rated.add(new RatedDocument(Integer.toString(i), relevanceRatings[i])); - hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); - } - assertEquals(13.84826362927298d, (new DiscountedCumulativeGainAtN(6)).evaluate(hits, rated).getQualityLevel(), 0.00001); - } - - - public void testParseFromXContent() throws IOException { - String xContent = " {\n" - + " \"size\": 8\n" - + "}"; - XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent); - DiscountedCumulativeGainAtN dcgAt = DiscountedCumulativeGainAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT); - assertEquals(8, dcgAt.getN()); - } -} diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java new file mode 100644 index 00000000000..59b05a35ade --- /dev/null +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainAtTests.java @@ -0,0 +1,121 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.internal.InternalSearchHit; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; + +public class DiscountedCumulativeGainAtTests extends ESTestCase { + + /** + * Assuming the docs are ranked in the following order: + * + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) + * ------------------------------------------------------------------------------------------- + * 1 | 3 | 7.0 | 1.0 | 7.0 + * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 3 | 3 | 7.0 | 2.0 | 3.5 + * 4 | 0 | 0.0 | 2.321928094887362 | 0.0 + * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 + * 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666 + * + * dcg = 13.84826362927298 (sum of last column) + */ + public void testDCGAtSix() throws IOException, InterruptedException, ExecutionException { + List rated = new ArrayList<>(); + int[] relevanceRatings = new int[] { 3, 2, 3, 0, 1, 2 }; + SearchHit[] hits = new InternalSearchHit[6]; + for (int i = 0; i < 6; i++) { + rated.add(new RatedDocument(Integer.toString(i), relevanceRatings[i])); + hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); + } + DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6); + assertEquals(13.84826362927298, dcg.evaluate(hits, rated).getQualityLevel(), 0.00001); + + /** + * Check with normalization: to get the maximal possible dcg, sort documents by relevance in descending order + * + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) + * ------------------------------------------------------------------------------------------- + * 1 | 3 | 7.0 | 1.0  | 7.0 + * 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202 + * 3 | 2 | 3.0 | 2.0  | 1.5 + * 4 | 2 | 3.0 | 2.321928094887362  | 1.2920296742201793 + * 5 | 1 | 1.0 | 2.584962500721156  | 0.38685280723454163 + * 6 | 0 | 0.0 | 2.807354922057604  | 0.0 + * + * idcg = 14.595390756454922 (sum of last column) + */ + dcg.setNormalize(true); + assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate(hits, rated).getQualityLevel(), 0.00001); + } + + /** + * This tests metric when some documents in the search result don't have a rating provided by the user. + * + * rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1) + * ------------------------------------------------------------------------------------------- + * 1 | 3 | 7.0 | 1.0 | 7.0 + * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 + * 3 | 3 | 7.0 | 2.0 | 3.5 + * 4 | n/a | n/a | n/a | n/a + * 5 | n/a | n/a | n/a | n/a + * 6 | n/a | n/a | n/a | n/a + * + * dcg = 13.84826362927298 (sum of last column) + */ + public void testDCGAtSixMissingRatings() throws IOException, InterruptedException, ExecutionException { + List rated = new ArrayList<>(); + int[] relevanceRatings = new int[] { 3, 2, 3}; + SearchHit[] hits = new InternalSearchHit[6]; + for (int i = 0; i < 6; i++) { + if (i < relevanceRatings.length) { + rated.add(new RatedDocument(Integer.toString(i), relevanceRatings[i])); + } + hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); + } + DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6); + EvalQueryQuality result = dcg.evaluate(hits, rated); + assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001); + assertEquals(3, result.getUnknownDocs().size()); + } + + public void testParseFromXContent() throws IOException { + String xContent = " {\n" + + " \"size\": 8,\n" + + " \"normalize\": true\n" + + "}"; + XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent); + DiscountedCumulativeGainAt dcgAt = DiscountedCumulativeGainAt.fromXContent(parser, () -> ParseFieldMatcher.STRICT); + assertEquals(8, dcgAt.getPosition()); + assertEquals(true, dcgAt.getNormalize()); + } +} diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java index 12dd808cff7..e87905cb1b4 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java @@ -58,11 +58,17 @@ public class ReciprocalRankTests extends ESTestCase { int rankAtFirstRelevant = relevantAt + 1; EvalQueryQuality evaluation = reciprocalRank.evaluate(hits, ratedDocs); - assertEquals(1.0 / rankAtFirstRelevant, evaluation.getQualityLevel(), Double.MIN_VALUE); + if (rankAtFirstRelevant <= maxRank) { + assertEquals(1.0 / rankAtFirstRelevant, evaluation.getQualityLevel(), Double.MIN_VALUE); - reciprocalRank = new ReciprocalRank(rankAtFirstRelevant - 1); - evaluation = reciprocalRank.evaluate(hits, ratedDocs); - assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE); + // check that if we lower maxRank by one, we don't find any result and get 0.0 quality level + reciprocalRank = new ReciprocalRank(rankAtFirstRelevant - 1); + evaluation = reciprocalRank.evaluate(hits, ratedDocs); + assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE); + + } else { + assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE); + } } public void testEvaluationOneRelevantInResults() {