Removing the 'size' parameter from the dcg metric
This commit is contained in:
parent
3855d7f721
commit
c8d9d063ca
|
@ -38,56 +38,36 @@ import java.util.stream.Collectors;
|
|||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
|
||||
|
||||
public class DiscountedCumulativeGainAt implements RankedListQualityMetric {
|
||||
public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
|
||||
/** rank position up to which to check results. */
|
||||
private int position;
|
||||
/** If set to true, the dcg will be normalized (ndcg) */
|
||||
private boolean normalize;
|
||||
/** If set to, this will be the rating for docs the user hasn't supplied an explicit rating for */
|
||||
private Integer unknownDocRating;
|
||||
|
||||
public static final String NAME = "dcg_at_n";
|
||||
public static final String NAME = "dcg";
|
||||
private static final double LOG2 = Math.log(2.0);
|
||||
|
||||
/**
|
||||
* Initialises position with 10
|
||||
* */
|
||||
public DiscountedCumulativeGainAt() {
|
||||
this.position = 10;
|
||||
public DiscountedCumulativeGain() {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param position number of top results to check against a given set of relevant results. Must be positive.
|
||||
*/
|
||||
public DiscountedCumulativeGainAt(int position) {
|
||||
if (position <= 0) {
|
||||
throw new IllegalArgumentException("number of results to check needs to be positive but was " + position);
|
||||
}
|
||||
this.position = position;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param position number of top results to check against a given set of relevant results. Must be positive.
|
||||
* @param normalize If set to true, dcg will be normalized (ndcg)
|
||||
* See https://en.wikipedia.org/wiki/Discounted_cumulative_gain
|
||||
* @param unknownDocRating the rating for docs the user hasn't supplied an explicit rating for
|
||||
* */
|
||||
public DiscountedCumulativeGainAt(int position, boolean normalize, Integer unknownDocRating) {
|
||||
this(position);
|
||||
public DiscountedCumulativeGain( boolean normalize, Integer unknownDocRating) {
|
||||
this.normalize = normalize;
|
||||
this.unknownDocRating = unknownDocRating;
|
||||
}
|
||||
|
||||
public DiscountedCumulativeGainAt(StreamInput in) throws IOException {
|
||||
this(in.readInt());
|
||||
public DiscountedCumulativeGain(StreamInput in) throws IOException {
|
||||
normalize = in.readBoolean();
|
||||
unknownDocRating = in.readOptionalVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeInt(position);
|
||||
out.writeBoolean(normalize);
|
||||
out.writeOptionalVInt(unknownDocRating);
|
||||
}
|
||||
|
@ -97,20 +77,6 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric {
|
|||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return number of search results to check for quality metric.
|
||||
*/
|
||||
public int getPosition() {
|
||||
return this.position;
|
||||
}
|
||||
|
||||
/**
|
||||
* set number of search results to check for quality metric.
|
||||
*/
|
||||
public void setPosition(int position) {
|
||||
this.position = position;
|
||||
}
|
||||
|
||||
/**
|
||||
* If set to true, the dcg will be normalized (ndcg)
|
||||
*/
|
||||
|
@ -143,8 +109,8 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric {
|
|||
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
|
||||
List<Integer> allRatings = ratedDocs.stream().mapToInt(RatedDocument::getRating).boxed().collect(Collectors.toList());
|
||||
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
|
||||
List<Integer> ratingsInSearchHits = new ArrayList<>(Math.min(ratedHits.size(), position));
|
||||
for (RatedSearchHit hit : ratedHits.subList(0, position)) {
|
||||
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
|
||||
for (RatedSearchHit hit : ratedHits) {
|
||||
// unknownDocRating might be null, which means it will be unrated docs are ignored in the dcg calculation
|
||||
// we still need to add them as a placeholder so the rank of the subsequent ratings is correct
|
||||
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
|
||||
|
@ -173,19 +139,17 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric {
|
|||
return dcg;
|
||||
}
|
||||
|
||||
private static final ParseField SIZE_FIELD = new ParseField("size");
|
||||
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
|
||||
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
|
||||
private static final ObjectParser<DiscountedCumulativeGainAt, ParseFieldMatcherSupplier> PARSER =
|
||||
new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGainAt());
|
||||
private static final ObjectParser<DiscountedCumulativeGain, ParseFieldMatcherSupplier> PARSER =
|
||||
new ObjectParser<>("dcg_at", () -> new DiscountedCumulativeGain());
|
||||
|
||||
static {
|
||||
PARSER.declareInt(DiscountedCumulativeGainAt::setPosition, SIZE_FIELD);
|
||||
PARSER.declareBoolean(DiscountedCumulativeGainAt::setNormalize, NORMALIZE_FIELD);
|
||||
PARSER.declareInt(DiscountedCumulativeGainAt::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD);
|
||||
PARSER.declareBoolean(DiscountedCumulativeGain::setNormalize, NORMALIZE_FIELD);
|
||||
PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD);
|
||||
}
|
||||
|
||||
public static DiscountedCumulativeGainAt fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
|
||||
public static DiscountedCumulativeGain fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
|
||||
return PARSER.apply(parser, matcher);
|
||||
}
|
||||
|
||||
|
@ -193,7 +157,6 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric {
|
|||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.startObject(NAME);
|
||||
builder.field(SIZE_FIELD.getPreferredName(), this.position);
|
||||
builder.field(NORMALIZE_FIELD.getPreferredName(), this.normalize);
|
||||
if (unknownDocRating != null) {
|
||||
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
|
||||
|
@ -211,15 +174,14 @@ public class DiscountedCumulativeGainAt implements RankedListQualityMetric {
|
|||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
DiscountedCumulativeGainAt other = (DiscountedCumulativeGainAt) obj;
|
||||
return Objects.equals(position, other.position) &&
|
||||
Objects.equals(normalize, other.normalize) &&
|
||||
DiscountedCumulativeGain other = (DiscountedCumulativeGain) obj;
|
||||
return Objects.equals(normalize, other.normalize) &&
|
||||
Objects.equals(unknownDocRating, other.unknownDocRating);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(position, normalize, unknownDocRating);
|
||||
return Objects.hash(normalize, unknownDocRating);
|
||||
}
|
||||
|
||||
// TODO maybe also add debugging breakdown here
|
|
@ -52,8 +52,8 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
|
|||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, Precision.NAME, Precision::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGainAt.NAME,
|
||||
DiscountedCumulativeGainAt::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME,
|
||||
DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, Precision.NAME,
|
||||
Precision.Breakdown::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, ReciprocalRank.NAME,
|
||||
|
|
|
@ -70,8 +70,8 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
|
|||
case ReciprocalRank.NAME:
|
||||
rc = ReciprocalRank.fromXContent(parser, context);
|
||||
break;
|
||||
case DiscountedCumulativeGainAt.NAME:
|
||||
rc = DiscountedCumulativeGainAt.fromXContent(parser, context);
|
||||
case DiscountedCumulativeGain.NAME:
|
||||
rc = DiscountedCumulativeGain.fromXContent(parser, context);
|
||||
break;
|
||||
default:
|
||||
throw new ParsingException(parser.getTokenLocation(), "[_na] unknown query metric name [{}]", metricName);
|
||||
|
|
|
@ -45,7 +45,7 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW
|
|||
public class ReciprocalRank implements RankedListQualityMetric {
|
||||
|
||||
public static final String NAME = "reciprocal_rank";
|
||||
public static final int DEFAULT_MAX_ACCEPTABLE_RANK = 10;
|
||||
public static final int DEFAULT_MAX_ACCEPTABLE_RANK = Integer.MAX_VALUE;
|
||||
private int maxAcceptableRank = DEFAULT_MAX_ACCEPTABLE_RANK;
|
||||
|
||||
/** ratings equal or above this value will be considered relevant. */
|
||||
|
|
|
@ -30,14 +30,12 @@ import org.elasticsearch.test.ESTestCase;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments;
|
||||
|
||||
public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
||||
public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
|
||||
/**
|
||||
* Assuming the docs are ranked in the following order:
|
||||
|
@ -53,7 +51,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
|||
*
|
||||
* dcg = 13.84826362927298 (sum of last column)
|
||||
*/
|
||||
public void testDCGAt() throws IOException, InterruptedException, ExecutionException {
|
||||
public void testDCGAt() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
int[] relevanceRatings = new int[] { 3, 2, 3, 0, 1, 2 };
|
||||
InternalSearchHit[] hits = new InternalSearchHit[6];
|
||||
|
@ -62,7 +60,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
|||
hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap());
|
||||
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0)));
|
||||
}
|
||||
DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6);
|
||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||
assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
|
||||
/**
|
||||
|
@ -97,7 +95,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
|||
*
|
||||
* dcg = 12.779642067948913 (sum of last column)
|
||||
*/
|
||||
public void testDCGAtSixMissingRatings() throws IOException, InterruptedException, ExecutionException {
|
||||
public void testDCGAtSixMissingRatings() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
Integer[] relevanceRatings = new Integer[] { 3, 2, 3, null, 1};
|
||||
InternalSearchHit[] hits = new InternalSearchHit[6];
|
||||
|
@ -110,7 +108,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
|||
hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap());
|
||||
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0)));
|
||||
}
|
||||
DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6);
|
||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||
EvalQueryQuality result = dcg.evaluate("id", hits, rated);
|
||||
assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001);
|
||||
assertEquals(2, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
||||
|
@ -149,21 +147,24 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
|||
*
|
||||
* dcg = 12.392789260714371 (sum of last column until position 4)
|
||||
*/
|
||||
public void testDCGAtFourMoreRatings() throws IOException, InterruptedException, ExecutionException {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
public void testDCGAtFourMoreRatings() {
|
||||
Integer[] relevanceRatings = new Integer[] { 3, 2, 3, null, 1, null};
|
||||
InternalSearchHit[] hits = new InternalSearchHit[6];
|
||||
List<RatedDocument> ratedDocs = new ArrayList<>();
|
||||
for (int i = 0; i < 6; i++) {
|
||||
if (i < relevanceRatings.length) {
|
||||
if (relevanceRatings[i] != null) {
|
||||
rated.add(new RatedDocument("index", "type", Integer.toString(i), relevanceRatings[i]));
|
||||
ratedDocs.add(new RatedDocument("index", "type", Integer.toString(i), relevanceRatings[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
// only create four hits
|
||||
InternalSearchHit[] hits = new InternalSearchHit[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap());
|
||||
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0)));
|
||||
}
|
||||
DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(4);
|
||||
EvalQueryQuality result = dcg.evaluate("id", Arrays.copyOfRange(hits, 0, 4), rated);
|
||||
DiscountedCumulativeGain dcg = new DiscountedCumulativeGain();
|
||||
EvalQueryQuality result = dcg.evaluate("id", hits, ratedDocs);
|
||||
assertEquals(12.392789260714371 , result.getQualityLevel(), 0.00001);
|
||||
assertEquals(1, filterUnknownDocuments(result.getHitsAndRatings()).size());
|
||||
|
||||
|
@ -183,33 +184,32 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
|
|||
* idcg = 13.347184833073591 (sum of last column)
|
||||
*/
|
||||
dcg.setNormalize(true);
|
||||
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
|
||||
}
|
||||
|
||||
public void testParseFromXContent() throws IOException {
|
||||
String xContent = " {\n"
|
||||
+ " \"size\": 8,\n"
|
||||
+ " \"unknown_doc_rating\": 2,\n"
|
||||
+ " \"normalize\": true\n"
|
||||
+ "}";
|
||||
XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent);
|
||||
DiscountedCumulativeGainAt dcgAt = DiscountedCumulativeGainAt.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
||||
assertEquals(8, dcgAt.getPosition());
|
||||
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
||||
assertEquals(2, dcgAt.getUnknownDocRating().intValue());
|
||||
assertEquals(true, dcgAt.getNormalize());
|
||||
}
|
||||
|
||||
public static DiscountedCumulativeGainAt createTestItem() {
|
||||
int position = randomIntBetween(0, 1000);
|
||||
public static DiscountedCumulativeGain createTestItem() {
|
||||
boolean normalize = randomBoolean();
|
||||
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
|
||||
|
||||
return new DiscountedCumulativeGainAt(position, normalize, unknownDocRating);
|
||||
return new DiscountedCumulativeGain(normalize, unknownDocRating);
|
||||
}
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
DiscountedCumulativeGainAt testItem = createTestItem();
|
||||
DiscountedCumulativeGain testItem = createTestItem();
|
||||
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
|
||||
itemParser.nextToken();
|
||||
itemParser.nextToken();
|
||||
DiscountedCumulativeGainAt parsedItem = DiscountedCumulativeGainAt.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
|
||||
DiscountedCumulativeGain parsedItem = DiscountedCumulativeGain.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
|
||||
assertNotSame(testItem, parsedItem);
|
||||
assertEquals(testItem, parsedItem);
|
||||
assertEquals(testItem.hashCode(), parsedItem.hashCode());
|
|
@ -93,7 +93,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
if (randomBoolean()) {
|
||||
metric = PrecisionTests.createTestItem();
|
||||
} else {
|
||||
metric = DiscountedCumulativeGainAtTests.createTestItem();
|
||||
metric = DiscountedCumulativeGainTests.createTestItem();
|
||||
}
|
||||
|
||||
RankEvalSpec testItem = new RankEvalSpec(specs, metric);
|
||||
|
|
|
@ -34,7 +34,6 @@ import java.util.ArrayList;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
public class ReciprocalRankTests extends ESTestCase {
|
||||
|
||||
|
@ -98,7 +97,7 @@ public class ReciprocalRankTests extends ESTestCase {
|
|||
* e.g. we set it to 2 here and expect dics 0-2 to be not relevant, so first relevant doc has
|
||||
* third ranking position, so RR should be 1/3
|
||||
*/
|
||||
public void testPrecisionAtFiveRelevanceThreshold() throws IOException, InterruptedException, ExecutionException {
|
||||
public void testPrecisionAtFiveRelevanceThreshold() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test", "testtype", "0", 0));
|
||||
rated.add(new RatedDocument("test", "testtype", "1", 1));
|
||||
|
|
|
@ -60,7 +60,7 @@
|
|||
{"_index" : "foo", "_type" : "bar", "_id" : "doc6", "rating": 2}]
|
||||
}
|
||||
],
|
||||
"metric" : { "dcg_at_n": { "size": 6}}
|
||||
"metric" : { "dcg": {}}
|
||||
}
|
||||
|
||||
- match: {rank_eval.quality_level: 13.84826362927298}
|
||||
|
@ -85,7 +85,7 @@
|
|||
{"_index" : "foo", "_type" : "bar", "_id" : "doc6", "rating": 2}]
|
||||
},
|
||||
],
|
||||
"metric" : { "dcg_at_n": { "size": 6}}
|
||||
"metric" : { "dcg": { }}
|
||||
}
|
||||
|
||||
- match: {rank_eval.quality_level: 10.29967439154499}
|
||||
|
@ -121,7 +121,7 @@
|
|||
{"_index" : "foo", "_type" : "bar", "_id" : "doc6", "rating": 2}]
|
||||
},
|
||||
],
|
||||
"metric" : { "dcg_at_n": { "size": 6}}
|
||||
"metric" : { "dcg": { }}
|
||||
}
|
||||
|
||||
- match: {rank_eval.quality_level: 12.073969010408984}
|
||||
|
|
Loading…
Reference in New Issue