Removing the 'size' parameter from precision metric

This commit is contained in:
Christoph Büscher 2016-11-03 16:45:42 +01:00
parent d013a2c8f1
commit 3855d7f721
9 changed files with 77 additions and 104 deletions

View File

@ -23,7 +23,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
@ -38,38 +38,36 @@ import javax.naming.directory.SearchResult;
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
/**
* Evaluate Precision at N, N being the number of search results to consider for precision calculation.
* Documents of unkonwn quality are ignored in the precision at n computation and returned by document id.
* By default documents with a rating equal or bigger than 1 are considered to be "relevant" for the precision
* calculation. This value can be changes using the "relevant_rating_threshold" parameter.
* */
public class PrecisionAtN implements RankedListQualityMetric {
/** Number of results to check against a given set of relevant results. */
private int n;
* Evaluate Precision of the search results. Documents without a rating are
* ignored. By default documents with a rating equal or bigger than 1 are
* considered to be "relevant" for the precision calculation. This value can be
* changes using the "relevant_rating_threshold" parameter.
*/
public class Precision implements RankedListQualityMetric {
/** ratings equal or above this value will be considered relevant. */
private int relevantRatingThreshhold = 1;
public static final String NAME = "precision_atn";
public static final String NAME = "precision";
private static final ParseField SIZE_FIELD = new ParseField("size");
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ConstructingObjectParser<PrecisionAtN, ParseFieldMatcherSupplier> PARSER = new ConstructingObjectParser<>(
"precision_at", a -> new PrecisionAtN((Integer) a[0]));
private static final ObjectParser<Precision, ParseFieldMatcherSupplier> PARSER = new ObjectParser<>(NAME, Precision::new);
static {
PARSER.declareInt(ConstructingObjectParser.constructorArg(), SIZE_FIELD);
PARSER.declareInt(PrecisionAtN::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
public Precision() {
// needed for supplier in parser
}
public PrecisionAtN(StreamInput in) throws IOException {
n = in.readInt();
static {
PARSER.declareInt(Precision::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
}
public Precision(StreamInput in) throws IOException {
relevantRatingThreshhold = in.readOptionalVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(n);
out.writeOptionalVInt(relevantRatingThreshhold);
}
@Override
@ -77,27 +75,6 @@ public class PrecisionAtN implements RankedListQualityMetric {
return NAME;
}
/**
* Initialises n with 10
* */
public PrecisionAtN() {
this.n = 10;
}
/**
* @param n number of top results to check against a given set of relevant results.
* */
public PrecisionAtN(int n) {
this.n= n;
}
/**
* Return number of search results to check for quality.
* */
public int getN() {
return n;
}
/**
* Sets the rating threshold above which ratings are considered to be "relevant" for this metric.
* */
@ -113,7 +90,7 @@ public class PrecisionAtN implements RankedListQualityMetric {
return relevantRatingThreshhold ;
}
public static PrecisionAtN fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
public static Precision fromXContent(XContentParser parser, ParseFieldMatcherSupplier matcher) {
return PARSER.apply(parser, matcher);
}
@ -140,7 +117,7 @@ public class PrecisionAtN implements RankedListQualityMetric {
precision = (double) truePositives / (truePositives + falsePositives);
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision);
evalQueryQuality.addMetricDetails(new PrecisionAtN.Breakdown(truePositives, truePositives + falsePositives));
evalQueryQuality.addMetricDetails(new Precision.Breakdown(truePositives, truePositives + falsePositives));
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
return evalQueryQuality;
}
@ -173,7 +150,7 @@ public class PrecisionAtN implements RankedListQualityMetric {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(SIZE_FIELD.getPreferredName(), this.n);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
builder.endObject();
builder.endObject();
return builder;
@ -187,13 +164,13 @@ public class PrecisionAtN implements RankedListQualityMetric {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PrecisionAtN other = (PrecisionAtN) obj;
return Objects.equals(n, other.n);
Precision other = (Precision) obj;
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold);
}
@Override
public final int hashCode() {
return Objects.hash(n);
return Objects.hash(relevantRatingThreshhold);
}
public static class Breakdown implements MetricDetails {
@ -247,7 +224,7 @@ public class PrecisionAtN implements RankedListQualityMetric {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PrecisionAtN.Breakdown other = (PrecisionAtN.Breakdown) obj;
Precision.Breakdown other = (Precision.Breakdown) obj;
return Objects.equals(relevantRetrieved, other.relevantRetrieved) &&
Objects.equals(retrieved, other.retrieved);
}

View File

@ -50,12 +50,12 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
@Override
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtN.NAME, PrecisionAtN::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, Precision.NAME, Precision::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGainAt.NAME,
DiscountedCumulativeGainAt::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtN.NAME,
PrecisionAtN.Breakdown::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, Precision.NAME,
Precision.Breakdown::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, ReciprocalRank.NAME,
ReciprocalRank.Breakdown::new));
return namedWriteables;

View File

@ -64,8 +64,8 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
// TODO maybe switch to using a plugable registry later?
switch (metricName) {
case PrecisionAtN.NAME:
rc = PrecisionAtN.fromXContent(parser, context);
case Precision.NAME:
rc = Precision.fromXContent(parser, context);
break;
case ReciprocalRank.NAME:
rc = ReciprocalRank.fromXContent(parser, context);

View File

@ -44,7 +44,7 @@ public class EvalQueryQualityTests extends ESTestCase {
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAsciiOfLength(10), randomDoubleBetween(0.0, 1.0, true));
if (randomBoolean()) {
// TODO randomize this
evalQueryQuality.addMetricDetails(new PrecisionAtN.Breakdown(1, 5));
evalQueryQuality.addMetricDetails(new Precision.Breakdown(1, 5));
}
evalQueryQuality.addHitsAndRatings(ratedHits);
return evalQueryQuality;
@ -78,7 +78,7 @@ public class EvalQueryQualityTests extends ESTestCase {
break;
case 2:
if (metricDetails == null) {
metricDetails = new PrecisionAtN.Breakdown(1, 5);
metricDetails = new Precision.Breakdown(1, 5);
} else {
metricDetails = null;
}

View File

@ -24,7 +24,7 @@ import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.index.rankeval.Precision.Rating;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchHit;
@ -36,15 +36,15 @@ import java.util.Collections;
import java.util.List;
import java.util.Vector;
public class PrecisionAtNTests extends ESTestCase {
public class PrecisionTests extends ESTestCase {
public void testPrecisionAtFiveCalculation() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(new RatedDocument("test", "testtype", "0", Rating.RELEVANT.ordinal()));
EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", toSearchHits(rated, "test", "testtype"), rated);
EvalQueryQuality evaluated = (new Precision()).evaluate("id", toSearchHits(rated, "test", "testtype"), rated);
assertEquals(1, evaluated.getQualityLevel(), 0.00001);
assertEquals(1, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(1, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
public void testPrecisionAtFiveIgnoreOneResult() {
@ -54,10 +54,10 @@ public class PrecisionAtNTests extends ESTestCase {
rated.add(new RatedDocument("test", "testtype", "2", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "testtype", "3", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "testtype", "4", Rating.IRRELEVANT.ordinal()));
EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", toSearchHits(rated, "test", "testtype"), rated);
EvalQueryQuality evaluated = (new Precision()).evaluate("id", toSearchHits(rated, "test", "testtype"), rated);
assertEquals((double) 4 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(4, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(4, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
/**
@ -71,12 +71,12 @@ public class PrecisionAtNTests extends ESTestCase {
rated.add(new RatedDocument("test", "testtype", "2", 2));
rated.add(new RatedDocument("test", "testtype", "3", 3));
rated.add(new RatedDocument("test", "testtype", "4", 4));
PrecisionAtN precisionAtN = new PrecisionAtN(5);
Precision precisionAtN = new Precision();
precisionAtN.setRelevantRatingThreshhold(2);
EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test", "testtype"), rated);
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(3, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
public void testPrecisionAtFiveCorrectIndex() {
@ -87,10 +87,10 @@ public class PrecisionAtNTests extends ESTestCase {
rated.add(new RatedDocument("test", "testtype", "1", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "testtype", "2", Rating.IRRELEVANT.ordinal()));
// the following search hits contain only the last three documents
EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", toSearchHits(rated.subList(2, 5), "test", "testtype"), rated);
EvalQueryQuality evaluated = (new Precision()).evaluate("id", toSearchHits(rated.subList(2, 5), "test", "testtype"), rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
public void testPrecisionAtFiveCorrectType() {
@ -100,33 +100,35 @@ public class PrecisionAtNTests extends ESTestCase {
rated.add(new RatedDocument("test", "testtype", "0", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "testtype", "1", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "testtype", "2", Rating.IRRELEVANT.ordinal()));
EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", toSearchHits(rated.subList(2, 5), "test", "testtype"), rated);
EvalQueryQuality evaluated = (new Precision()).evaluate("id", toSearchHits(rated.subList(2, 5), "test", "testtype"), rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(2, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
public void testNoRatedDocs() throws Exception {
List<RatedDocument> rated = new ArrayList<>();
EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", toSearchHits(rated, "test", "testtype"), rated);
InternalSearchHit[] hits = new InternalSearchHit[5];
for (int i = 0; i < 5; i++) {
hits[i] = new InternalSearchHit(i, i+"", new Text("type"), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new Index("index", "uuid"), 0));
}
EvalQueryQuality evaluated = (new Precision()).evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((Precision.Breakdown) evaluated.getMetricDetails()).getRetrieved());
}
public void testParseFromXContent() throws IOException {
String xContent = " {\n"
+ " \"size\": 10,\n"
+ " \"relevant_rating_threshold\" : 2"
+ "}";
XContentParser parser = XContentFactory.xContent(xContent).createParser(xContent);
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
assertEquals(10, precicionAt.getN());
Precision precicionAt = Precision.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
assertEquals(2, precicionAt.getRelevantRatingThreshold());
}
public void testCombine() {
PrecisionAtN metric = new PrecisionAtN();
Precision metric = new Precision();
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
partialResults.add(new EvalQueryQuality("a", 0.1));
partialResults.add(new EvalQueryQuality("b", 0.2));
@ -134,17 +136,20 @@ public class PrecisionAtNTests extends ESTestCase {
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
}
public static PrecisionAtN createTestItem() {
int position = randomIntBetween(0, 1000);
return new PrecisionAtN(position);
public static Precision createTestItem() {
Precision precision = new Precision();
if (randomBoolean()) {
precision.setRelevantRatingThreshhold(randomIntBetween(0, 10));
}
return precision;
}
public void testXContentRoundtrip() throws IOException {
PrecisionAtN testItem = createTestItem();
Precision testItem = createTestItem();
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
itemParser.nextToken();
itemParser.nextToken();
PrecisionAtN parsedItem = PrecisionAtN.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
Precision parsedItem = Precision.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());

View File

@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.index.rankeval.Precision.Rating;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
@ -84,7 +84,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
berlinRequest.setSummaryFields(Arrays.asList(new String[]{ "text", "title" }));
specifications.add(berlinRequest);
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
RankEvalSpec task = new RankEvalSpec(specifications, new Precision());
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
@ -140,7 +140,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
brokenQuery.query(brokenRangeQuery);
specifications.add(new RatedRequest("broken_query", brokenQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
RankEvalSpec task = new RankEvalSpec(specifications, new Precision());
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);

View File

@ -23,7 +23,6 @@ import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
@ -57,7 +56,7 @@ public class RankEvalSpecTests extends ESTestCase {
* setup for the whole base test class
*/
@BeforeClass
public static void init() throws IOException {
public static void init() {
AggregatorParsers aggsParsers = new AggregatorParsers(new ParseFieldRegistry<>("aggregation"),
new ParseFieldRegistry<>("aggregation_pipes"));
searchModule = new SearchModule(Settings.EMPTY, false, emptyList());
@ -92,22 +91,19 @@ public class RankEvalSpecTests extends ESTestCase {
RankedListQualityMetric metric;
if (randomBoolean()) {
metric = PrecisionAtNTests.createTestItem();
metric = PrecisionTests.createTestItem();
} else {
metric = DiscountedCumulativeGainAtTests.createTestItem();
}
RankEvalSpec testItem = new RankEvalSpec(specs, metric);
XContentType contentType = ESTestCase.randomFrom(XContentType.values());
XContent xContent = contentType.xContent();
if (randomBoolean()) {
final Map<String, Object> params = randomBoolean() ? null : Collections.singletonMap("key", "value");
ScriptType scriptType = randomFrom(ScriptType.values());
String script;
if (scriptType == ScriptType.INLINE) {
try (XContentBuilder builder = XContentBuilder.builder(xContent)) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
builder.field("field", randomAsciiOfLengthBetween(1, 5));
builder.endObject();
@ -122,15 +118,10 @@ public class RankEvalSpecTests extends ESTestCase {
scriptType,
randomFrom("_lang1", "_lang2", null),
params,
scriptType == ScriptType.INLINE ? xContent.type() : null));
scriptType == ScriptType.INLINE ? XContentType.JSON : null));
}
XContentBuilder builder = XContentFactory.contentBuilder(contentType);
if (ESTestCase.randomBoolean()) {
builder.prettyPrint();
}
testItem.toXContent(builder, ToXContent.EMPTY_PARAMS);
XContentBuilder shuffled = ESTestCase.shuffleXContent(builder);
XContentBuilder shuffled = ESTestCase.shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes());
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);

View File

@ -23,7 +23,7 @@ import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.index.rankeval.Precision.Rating;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.InternalSearchHit;

View File

@ -56,7 +56,7 @@
"ratings": [{"_index": "foo", "_type": "bar", "_id": "doc1", "rating": 1}]
}
],
"metric" : { "precision_atn": { "size": 10}}
"metric" : { "precision": { }}
}
- match: { rank_eval.quality_level: 1}