Add Discounted Cumulative Gain metric

This commit is contained in:
Christoph Büscher 2016-07-08 15:11:52 +02:00
parent 0578a96483
commit 87e13ca8bb
4 changed files with 296 additions and 0 deletions

View File

@ -0,0 +1,118 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class DiscountedCumulativeGainAtN extends RankedListQualityMetric {
/** Number of results to check against a given set of relevant results. */
private int n;
public static final String NAME = "dcg_at_n";
private static final double LOG2 = Math.log(2.0);
public DiscountedCumulativeGainAtN(StreamInput in) throws IOException {
n = in.readInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(n);
}
@Override
public String getWriteableName() {
return NAME;
}
/**
* Initialises n with 10
* */
public DiscountedCumulativeGainAtN() {
this.n = 10;
}
/**
* @param n number of top results to check against a given set of relevant results. Must be positive.
*/
public DiscountedCumulativeGainAtN(int n) {
if (n <= 0) {
throw new IllegalArgumentException("number of results to check needs to be positive but was " + n);
}
this.n = n;
}
/**
* Return number of search results to check for quality.
*/
public int getN() {
return n;
}
@Override
public EvalQueryQuality evaluate(SearchHit[] hits, List<RatedDocument> ratedDocs) {
Map<String, RatedDocument> ratedDocsById = new HashMap<>();
for (RatedDocument doc : ratedDocs) {
ratedDocsById.put(doc.getDocID(), doc);
}
Collection<String> unknownDocIds = new ArrayList<String>();
double dcg = 0;
for (int i = 0; (i < n && i < hits.length); i++) {
int rank = i + 1; // rank is 1-based
String id = hits[i].getId();
RatedDocument ratedDoc = ratedDocsById.get(id);
if (ratedDoc != null) {
int rel = ratedDoc.getRating();
dcg += (Math.pow(2, rel) - 1) / ((Math.log(rank + 1) / LOG2));
} else {
unknownDocIds.add(id);
}
}
return new EvalQueryQuality(dcg, unknownDocIds);
}
private static final ParseField SIZE_FIELD = new ParseField("size");
private static final ConstructingObjectParser<DiscountedCumulativeGainAtN, ParseFieldMatcherSupplier> PARSER =
new ConstructingObjectParser<>("dcg_at", a -> new DiscountedCumulativeGainAtN((Integer) a[0]));
static {
PARSER.declareInt(ConstructingObjectParser.constructorArg(), SIZE_FIELD);
}
public static DiscountedCumulativeGainAtN fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
return PARSER.apply(parser, matcher);
}
}

View File

@ -63,6 +63,9 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
case ReciprocalRank.NAME:
rc = ReciprocalRank.fromXContent(parser, context);
break;
case DiscountedCumulativeGainAtN.NAME:
rc = DiscountedCumulativeGainAtN.fromXContent(parser, context);
break;
default:
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.internal.InternalSearchHit;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
public class DiscountedCumulativeGainAtNTests extends ESTestCase {
/**
* Assuming the docs are ranked in the following order:
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
* 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666
*/
public void testDCGAtSix() throws IOException, InterruptedException, ExecutionException {
List<RatedDocument> rated = new ArrayList<>();
int[] relevanceRatings = new int[] { 3, 2, 3, 0, 1, 2 };
SearchHit[] hits = new InternalSearchHit[6];
for (int i = 0; i < 6; i++) {
rated.add(new RatedDocument(Integer.toString(i), relevanceRatings[i]));
hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap());
}
assertEquals(13.84826362927298d, (new DiscountedCumulativeGainAtN(6)).evaluate(hits, rated).getQualityLevel(), 0.00001);
}
public void testParseFromXContent() throws IOException {
String xContent = " {\n"
+ " \"size\": 8\n"
+ "}";
XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent);
DiscountedCumulativeGainAtN dcgAt = DiscountedCumulativeGainAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
assertEquals(8, dcgAt.getN());
}
}

View File

@ -0,0 +1,105 @@
---
"Response format":
- do:
index:
index: foo
type: bar
id: doc1
body: { "bar": 1 }
- do:
index:
index: foo
type: bar
id: doc2
body: { "bar": 2 }
- do:
index:
index: foo
type: bar
id: doc3
body: { "bar": 3 }
- do:
index:
index: foo
type: bar
id: doc4
body: { "bar": 4 }
- do:
index:
index: foo
type: bar
id: doc5
body: { "bar": 5 }
- do:
index:
index: foo
type: bar
id: doc6
body: { "bar": 6 }
- do:
indices.refresh: {}
- do:
rank_eval:
body: {
"spec_id" : "dcg_qa_queries",
"requests" : [
{
"id": "dcg_query",
"request": { "query": { "match_all" : {}}, "sort" : [ "bar" ] },
"ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}]
}
],
"metric" : { "dcg_at_n": { "size": 6}}
}
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 13.84826362927298}
# reverse the order in which the results are returned (less relevant docs first)
- do:
rank_eval:
body: {
"spec_id" : "dcg_qa_queries",
"requests" : [
{
"id": "dcg_query_reverse",
"request": { "query": { "match_all" : {}}, "sort" : [ {"bar" : "desc" }] },
"ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}]
},
],
"metric" : { "dcg_at_n": { "size": 6}}
}
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 10.29967439154499}
# if we mix both, we should get the average
- do:
rank_eval:
body: {
"spec_id" : "dcg_qa_queries",
"requests" : [
{
"id": "dcg_query",
"request": { "query": { "match_all" : {}}, "sort" : [ "bar" ] },
"ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}]
},
{
"id": "dcg_query_reverse",
"request": { "query": { "match_all" : {}}, "sort" : [ {"bar" : "desc" }] },
"ratings": [{ "doc1": 3}, {"doc2": 2}, {"doc3": 3}, {"doc4": 0}, {"doc5": 1}, {"doc6": 2}]
},
],
"metric" : { "dcg_at_n": { "size": 6}}
}
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 12.073969010408984}