RankEval: Add optional parameter to ignore unlabeled documents in Precision metric
Our current default behaviour of ignoring unrated documents when calculating precision seems a bit counterintuitive. Instead, we should treat those documents as "irrelevant" by default and provide an optional parameter to ignore them if that is the behaviour the user wants.
This commit is contained in:
parent
4718f000df
commit
4e5c868709
|
@ -45,29 +45,40 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW
|
|||
*/
|
||||
public class Precision implements RankedListQualityMetric {
|
||||
|
||||
/** ratings equal or above this value will be considered relevant. */
|
||||
private int relevantRatingThreshhold = 1;
|
||||
|
||||
public static final String NAME = "precision";
|
||||
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
|
||||
private static final ObjectParser<Precision, ParseFieldMatcherSupplier> PARSER = new ObjectParser<>(NAME, Precision::new);
|
||||
|
||||
/**
|
||||
* This setting controls how unlabeled documents in the search hits are
|
||||
* treated. Set to 'true', unlabeled documents are ignored and neither count
|
||||
* as true nor as false positives. Set to 'false', they are treated as false positives.
|
||||
*/
|
||||
private boolean ignoreUnlabeled = false;
|
||||
|
||||
/** ratings equal to or above this value will be considered relevant. */
|
||||
private int relevantRatingThreshhold = 1;
|
||||
|
||||
public Precision() {
|
||||
// needed for supplier in parser
|
||||
}
|
||||
|
||||
static {
|
||||
PARSER.declareInt(Precision::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
|
||||
PARSER.declareBoolean(Precision::setIgnoreUnlabeled, IGNORE_UNLABELED_FIELD);
|
||||
}
|
||||
|
||||
public Precision(StreamInput in) throws IOException {
|
||||
relevantRatingThreshhold = in.readOptionalVInt();
|
||||
ignoreUnlabeled = in.readOptionalBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalVInt(relevantRatingThreshhold);
|
||||
out.writeOptionalBoolean(ignoreUnlabeled);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -90,6 +101,20 @@ public class Precision implements RankedListQualityMetric {
|
|||
return relevantRatingThreshhold ;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the 'ignore_unlabeled' parameter
|
||||
* */
|
||||
public void setIgnoreUnlabeled(boolean ignoreUnlabeled) {
|
||||
this.ignoreUnlabeled = ignoreUnlabeled;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the 'ignore_unlabeled' parameter
|
||||
* */
|
||||
public boolean getIgnoreUnlabeled() {
|
||||
return ignoreUnlabeled;
|
||||
}
|
||||
|
||||
public static Precision fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
|
||||
return PARSER.apply(parser, matcher);
|
||||
}
|
||||
|
@ -110,6 +135,8 @@ public class Precision implements RankedListQualityMetric {
|
|||
} else {
|
||||
falsePositives++;
|
||||
}
|
||||
} else if (ignoreUnlabeled == false) {
|
||||
falsePositives++;
|
||||
}
|
||||
}
|
||||
double precision = 0.0;
|
||||
|
@ -122,35 +149,12 @@ public class Precision implements RankedListQualityMetric {
|
|||
return evalQueryQuality;
|
||||
}
|
||||
|
||||
// TODO add abstraction that also works for other metrics
|
||||
public enum Rating {
|
||||
IRRELEVANT, RELEVANT;
|
||||
}
|
||||
|
||||
/**
|
||||
* Needed to get the enum across serialisation boundaries.
|
||||
* */
|
||||
public static class RatingMapping {
|
||||
public static Integer mapFrom(Rating rating) {
|
||||
if (Rating.RELEVANT.equals(rating)) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static Rating mapTo(Integer rating) {
|
||||
if (rating == 1) {
|
||||
return Rating.RELEVANT;
|
||||
}
|
||||
return Rating.IRRELEVANT;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.startObject(NAME);
|
||||
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
|
||||
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -165,12 +169,13 @@ public class Precision implements RankedListQualityMetric {
|
|||
return false;
|
||||
}
|
||||
Precision other = (Precision) obj;
|
||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
|
||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold) &&
|
||||
Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(relevantRatingThreshhold);
|
||||
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled);
|
||||
}
|
||||
|
||||
public static class Breakdown implements MetricDetails {
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.elasticsearch.common.text.Text;
|
|||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.Index;
|
||||
import org.elasticsearch.index.rankeval.Precision.Rating;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchShardTarget;
|
||||
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||
|
@ -32,6 +31,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
@ -106,6 +106,29 @@ public class PrecisionTests extends ESTestCase {
|
|||
assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
}
|
||||
|
||||
public void testIgnoreUnlabeled() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test", "testtype", "0", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "testtype", "1", Rating.RELEVANT.ordinal()));
|
||||
// add an unlabeled search hit
|
||||
SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test", "testtype"), 3);
|
||||
searchHits[2] = new InternalSearchHit(2, "2", new Text("testtype"), Collections.emptyMap());
|
||||
((InternalSearchHit)searchHits[2]).shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0));
|
||||
|
||||
EvalQueryQuality evaluated = (new Precision()).evaluate("id", searchHits, rated);
|
||||
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
Precision prec = new Precision();
|
||||
prec.setIgnoreUnlabeled(true);
|
||||
evaluated = prec.evaluate("id", searchHits, rated);
|
||||
assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
}
|
||||
|
||||
public void testNoRatedDocs() throws Exception {
|
||||
InternalSearchHit[] hits = new InternalSearchHit[5];
|
||||
for (int i = 0; i < 5; i++) {
|
||||
|
@ -115,6 +138,14 @@ public class PrecisionTests extends ESTestCase {
|
|||
EvalQueryQuality evaluated = (new Precision()).evaluate("id", hits, Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(5, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
Precision prec = new Precision();
|
||||
prec.setIgnoreUnlabeled(true);
|
||||
evaluated = prec.evaluate("id", hits, Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
}
|
||||
|
||||
|
@ -141,6 +172,7 @@ public class PrecisionTests extends ESTestCase {
|
|||
if (randomBoolean()) {
|
||||
precision.setRelevantRatingThreshhold(randomIntBetween(0, 10));
|
||||
}
|
||||
precision.setIgnoreUnlabeled(randomBoolean());
|
||||
return precision;
|
||||
}
|
||||
|
||||
|
@ -163,4 +195,8 @@ public class PrecisionTests extends ESTestCase {
|
|||
}
|
||||
return hits;
|
||||
}
|
||||
|
||||
public enum Rating {
|
||||
IRRELEVANT, RELEVANT;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,13 +22,12 @@ package org.elasticsearch.index.rankeval;
|
|||
import org.elasticsearch.action.search.SearchPhaseExecutionException;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.RangeQueryBuilder;
|
||||
import org.elasticsearch.index.rankeval.Precision.Rating;
|
||||
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
|
@ -70,7 +69,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
|
|||
refresh();
|
||||
}
|
||||
|
||||
public void testPrecisionAtRequest() throws IOException {
|
||||
public void testPrecisionAtRequest() {
|
||||
List<String> indices = Arrays.asList(new String[] { "test" });
|
||||
List<String> types = Arrays.asList(new String[] { "testtype" });
|
||||
|
||||
|
@ -84,7 +83,9 @@ public class RankEvalRequestTests extends ESIntegTestCase {
|
|||
berlinRequest.setSummaryFields(Arrays.asList(new String[]{ "text", "title" }));
|
||||
specifications.add(berlinRequest);
|
||||
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, new Precision());
|
||||
Precision metric = new Precision();
|
||||
metric.setIgnoreUnlabeled(true);
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.elasticsearch.common.ParseFieldMatcher;
|
|||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.Index;
|
||||
import org.elasticsearch.index.rankeval.Precision.Rating;
|
||||
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchShardTarget;
|
||||
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||
|
|
|
@ -56,7 +56,7 @@
|
|||
"ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}]
|
||||
}
|
||||
],
|
||||
"metric" : { "precision": { }}
|
||||
"metric" : { "precision": { "ignore_unlabeled" : true }}
|
||||
}
|
||||
|
||||
- match: { rank_eval.quality_level: 1}
|
||||
|
|
Loading…
Reference in New Issue