RankEval: Add optional parameter to ignore unlabeled documents in Precision metric

Our current default behaviour of ignoring unrated documents when calculating the
precision seems a bit counterintuitive. Instead, we should treat those documents
as "irrelevant" by default and provide an optional parameter to ignore those
documents if that is the behaviour the user wants.
This commit is contained in:
Christoph Büscher 2016-11-08 15:36:59 +01:00
parent 4718f000df
commit 4e5c868709
5 changed files with 78 additions and 36 deletions

View File

@ -45,29 +45,40 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW
*/
public class Precision implements RankedListQualityMetric {
/** ratings equal or above this value will be considered relevant. */
private int relevantRatingThreshhold = 1;
public static final String NAME = "precision";
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
private static final ObjectParser<Precision, ParseFieldMatcherSupplier> PARSER = new ObjectParser<>(NAME, Precision::new);
/**
* This setting controls how unlabeled documents in the search hits are
* treated. Set to 'true', unlabeled documents are ignored and count as
* neither true nor false positives. Set to 'false', they are treated as false positives.
*/
private boolean ignoreUnlabeled = false;
/** ratings equal or above this value will be considered relevant. */
private int relevantRatingThreshhold = 1;
/**
* Creates a metric with the default settings: a relevant rating threshold of 1
* and unlabeled documents not ignored.
* The parser uses this constructor as its instance supplier ({@code Precision::new}).
*/
public Precision() {
// needed for supplier in parser
}
// Registers the request fields with the parser and binds each one to its setter:
// "relevant_rating_threshold" is parsed as an int, "ignore_unlabeled" as a boolean.
static {
PARSER.declareInt(Precision::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
PARSER.declareBoolean(Precision::setIgnoreUnlabeled, IGNORE_UNLABELED_FIELD);
}
/**
* Deserializes a metric from the stream. The relevant rating threshold is read
* first, then the ignore-unlabeled flag, mirroring the order used by
* {@link #writeTo(StreamOutput)}.
*
* @param in stream to read from
* @throws IOException propagated from the stream reads
*/
public Precision(StreamInput in) throws IOException {
// NOTE(review): both values are read with the "optional" variants and assigned to
// primitive fields; presumably a missing value would fail on unboxing - confirm that
// writeTo is the only producer of this stream format.
relevantRatingThreshhold = in.readOptionalVInt();
ignoreUnlabeled = in.readOptionalBoolean();
}
/**
* Serializes this metric: the relevant rating threshold followed by the
* ignore-unlabeled flag, both written with the "optional" stream encodings.
* The stream constructor reads the two values back in the same order.
*
* @param out stream to write to
* @throws IOException propagated from the stream writes
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalVInt(relevantRatingThreshhold);
out.writeOptionalBoolean(ignoreUnlabeled);
}
@Override
@ -90,6 +101,20 @@ public class Precision implements RankedListQualityMetric {
return relevantRatingThreshhold ;
}
/**
* Sets the 'ignore_unlabeled' parameter.
*
* @param ignoreUnlabeled when {@code false}, search hits without a rating are counted
*        as false positives; when {@code true}, they are left out of the calculation
*/
public void setIgnoreUnlabeled(boolean ignoreUnlabeled) {
this.ignoreUnlabeled = ignoreUnlabeled;
}
/**
* Gets the 'ignore_unlabeled' parameter.
*
* @return {@code true} if search hits without a rating are left out of the
*         calculation, {@code false} if they are counted as false positives
*/
public boolean getIgnoreUnlabeled() {
return ignoreUnlabeled;
}
/**
* Builds a {@link Precision} from its XContent representation by applying the
* class-level {@code PARSER} to the given parser.
*
* @param parser  the XContent parser handed to {@code PARSER}
* @param matcher the parse field matcher supplier handed to {@code PARSER}
* @return the parsed metric
*/
public static Precision fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
return PARSER.apply(parser, matcher);
}
@ -110,6 +135,8 @@ public class Precision implements RankedListQualityMetric {
} else {
falsePositives++;
}
} else if (ignoreUnlabeled == false) {
falsePositives++;
}
}
double precision = 0.0;
@ -122,35 +149,12 @@ public class Precision implements RankedListQualityMetric {
return evalQueryQuality;
}
// TODO add abstraction that also works for other metrics
/**
 * Two-valued relevance label for a document.
 * <p>
 * The declaration order must not change: the tests pass {@code ordinal()} as the
 * numeric rating, and {@code RatingMapping} translates {@link #RELEVANT} to 1 and
 * everything else to 0.
 */
public enum Rating {
    /** Declared first, so its ordinal is 0. */
    IRRELEVANT,
    /** Declared second, so its ordinal is 1. */
    RELEVANT
}
/**
 * Translates between {@link Rating} and its integer form, which is needed to get
 * the enum across serialisation boundaries.
 */
public static class RatingMapping {
    /**
     * Converts a rating to its integer form.
     *
     * @param rating the rating to convert; {@code null} is treated like any non-relevant value
     * @return 1 for {@link Rating#RELEVANT}, 0 otherwise
     */
    public static Integer mapFrom(Rating rating) {
        return rating == Rating.RELEVANT ? 1 : 0;
    }

    /**
     * Converts an integer back to a rating.
     *
     * @param rating the integer form; must not be {@code null}, because the value is unboxed
     *        for the comparison and a {@code null} raises a {@link NullPointerException}
     * @return {@link Rating#RELEVANT} for exactly 1, {@link Rating#IRRELEVANT} for any other value
     */
    public static Rating mapTo(Integer rating) {
        if (rating != 1) {
            return Rating.IRRELEVANT;
        }
        return Rating.RELEVANT;
    }
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
builder.endObject();
builder.endObject();
return builder;
@ -165,12 +169,13 @@ public class Precision implements RankedListQualityMetric {
return false;
}
Precision other = (Precision) obj;
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold) &&
Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
}
@Override
public final int hashCode() {
return Objects.hash(relevantRatingThreshhold);
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled);
}
public static class Breakdown implements MetricDetails {

View File

@ -24,7 +24,6 @@ import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.Precision.Rating;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchHit;
@ -32,6 +31,7 @@ import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
@ -106,6 +106,29 @@ public class PrecisionTests extends ESTestCase {
assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
/**
 * Checks how a search hit without any rating influences precision, first with the
 * default settings and then with {@code ignore_unlabeled} switched on.
 */
public void testIgnoreUnlabeled() {
    List<RatedDocument> ratedDocs = new ArrayList<>();
    for (String docId : new String[] { "0", "1" }) {
        ratedDocs.add(new RatedDocument("test", "testtype", docId, Rating.RELEVANT.ordinal()));
    }

    // two hits that have a rating, followed by a third hit that has none
    SearchHit[] hits = Arrays.copyOf(toSearchHits(ratedDocs, "test", "testtype"), 3);
    InternalSearchHit unlabeledHit = new InternalSearchHit(2, "2", new Text("testtype"), Collections.emptyMap());
    unlabeledHit.shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0));
    hits[2] = unlabeledHit;

    // default: the unlabeled hit counts as retrieved but not as relevant
    EvalQueryQuality result = new Precision().evaluate("id", hits, ratedDocs);
    Precision.Breakdown breakdown = (Precision.Breakdown) result.getMetricDetails();
    assertEquals(2.0 / 3.0, result.getQualityLevel(), 0.00001);
    assertEquals(2, breakdown.getRelevantRetrieved());
    assertEquals(3, breakdown.getRetrieved());

    // also try with setting `ignore_unlabeled`: the unlabeled hit drops out entirely
    Precision ignoringUnlabeled = new Precision();
    ignoringUnlabeled.setIgnoreUnlabeled(true);
    result = ignoringUnlabeled.evaluate("id", hits, ratedDocs);
    breakdown = (Precision.Breakdown) result.getMetricDetails();
    assertEquals(1.0, result.getQualityLevel(), 0.00001);
    assertEquals(2, breakdown.getRelevantRetrieved());
    assertEquals(2, breakdown.getRetrieved());
}
public void testNoRatedDocs() throws Exception {
InternalSearchHit[] hits = new InternalSearchHit[5];
for (int i = 0; i < 5; i++) {
@ -115,6 +138,14 @@ public class PrecisionTests extends ESTestCase {
EvalQueryQuality evaluated = (new Precision()).evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
Precision prec = new Precision();
prec.setIgnoreUnlabeled(true);
evaluated = prec.evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
@ -141,6 +172,7 @@ public class PrecisionTests extends ESTestCase {
if (randomBoolean()) {
precision.setRelevantRatingThreshhold(randomIntBetween(0, 10));
}
precision.setIgnoreUnlabeled(randomBoolean());
return precision;
}
@ -163,4 +195,8 @@ public class PrecisionTests extends ESTestCase {
}
return hits;
}
/**
 * Two-valued relevance label used by the rank-eval tests.
 * <p>
 * The declaration order must not change: the tests pass {@code ordinal()} as the
 * numeric rating of a document.
 */
public enum Rating {
    /** Declared first, so its ordinal is 0. */
    IRRELEVANT,
    /** Declared second, so its ordinal is 1. */
    RELEVANT
}
}

View File

@ -22,13 +22,12 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.rankeval.Precision.Rating;
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@ -70,7 +69,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
refresh();
}
public void testPrecisionAtRequest() throws IOException {
public void testPrecisionAtRequest() {
List<String> indices = Arrays.asList(new String[] { "test" });
List<String> types = Arrays.asList(new String[] { "testtype" });
@ -84,7 +83,9 @@ public class RankEvalRequestTests extends ESIntegTestCase {
berlinRequest.setSummaryFields(Arrays.asList(new String[]{ "text", "title" }));
specifications.add(berlinRequest);
RankEvalSpec task = new RankEvalSpec(specifications, new Precision());
Precision metric = new Precision();
metric.setIgnoreUnlabeled(true);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);

View File

@ -23,7 +23,7 @@ import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.Precision.Rating;
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchHit;

View File

@ -56,7 +56,7 @@
"ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}]
}
],
"metric" : { "precision": { }}
"metric" : { "precision": { "ignore_unlabeled" : true }}
}
- match: { rank_eval.quality_level: 1}