Add k parameter to PrecisionAtK metric (#27569)

This commit is contained in:
Christoph Büscher 2017-11-29 15:19:16 +01:00 committed by GitHub
parent 1352b7c6ea
commit 7bfb273763
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 96 additions and 31 deletions

View File

@ -89,4 +89,13 @@ public interface EvaluationMetric extends ToXContent, NamedWriteable {
default double combine(Collection<EvalQueryQuality> partialResults) {
    // Arithmetic mean of the per-query quality levels: add them up, then divide by the number of partial results.
    final double totalQuality = partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum();
    return totalQuality / partialResults.size();
}
/**
* Metrics can define the size of the search hit window they want to retrieve by overriding
* this method. The default implementation returns an empty optional, i.e. it requests no
* particular size.
*
* @return the number of search hits this metric requests, or an empty optional if it requests none
*/
default Optional<Integer> forcedSearchSize() {
return Optional.empty();
}
}

View File

@ -45,6 +45,7 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
* relevant_rating_threshold` parameter.<br>
* The `ignore_unlabeled` parameter (defaults to false) controls if unrated
* documents should be ignored.<br>
* The `k` parameter (defaults to 10) controls the search window size.
*/
public class PrecisionAtK implements EvaluationMetric {
@ -52,9 +53,13 @@ public class PrecisionAtK implements EvaluationMetric {
// Names of the request parameters this metric accepts.
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
private static final ParseField K_FIELD = new ParseField("k");
// Window size used by the no-arg constructor.
private static final int DEFAULT_K = 10;
// When true, unlabeled documents are ignored instead of being treated as false positives.
private final boolean ignoreUnlabeled;
// Value of the `relevant_rating_threshold` parameter.
private final int relevantRatingThreshhold;
// Search window size; this is the value reported by forcedSearchSize().
private final int k;
/**
* Metric implementing Precision@K.
@ -65,42 +70,52 @@ public class PrecisionAtK implements EvaluationMetric {
* Set to 'true', unlabeled documents are ignored and neither count
* as true or false positives. Set to 'false', they are treated as
* false positives.
* @param k
* controls the window size for the search results the metric takes into account
*/
public PrecisionAtK(int threshold, boolean ignoreUnlabeled) {
public PrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) {
// Only negative thresholds are rejected here; a threshold of 0 passes this check.
if (threshold < 0) {
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
}
// The window size has to be at least 1: zero and negative values are both rejected.
if (k <= 0) {
throw new IllegalArgumentException("Window size k must be positive.");
}
this.relevantRatingThreshhold = threshold;
this.ignoreUnlabeled = ignoreUnlabeled;
this.k = k;
}
/**
* Creates the metric with its defaults: a relevant rating threshold of 1, unlabeled
* documents not ignored, and a window size of {@code DEFAULT_K}.
*/
public PrecisionAtK() {
this(1, false);
this(1, false, DEFAULT_K);
}
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
args -> {
Integer threshHold = (Integer) args[0];
Boolean ignoreUnlabeled = (Boolean) args[1];
Integer k = (Integer) args[2];
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
ignoreUnlabeled == null ? false : ignoreUnlabeled);
ignoreUnlabeled == null ? false : ignoreUnlabeled,
k == null ? 10 : k);
});
static {
// All three constructor arguments are optional. The declaration order is significant:
// it fixes the positions (0, 1, 2) at which the PARSER lambda reads them from `args`.
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
}
/**
* Reads the metric from a stream. The fields are read in the same order in which
* {@code writeTo} writes them. Unlike the public constructor, no range checks are applied.
*/
PrecisionAtK(StreamInput in) throws IOException {
relevantRatingThreshhold = in.readVInt();
ignoreUnlabeled = in.readBoolean();
k = in.readVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
// The write order has to stay in sync with the read order of the StreamInput constructor.
out.writeVInt(relevantRatingThreshhold);
out.writeBoolean(ignoreUnlabeled);
out.writeVInt(k);
}
@Override
@ -123,6 +138,11 @@ public class PrecisionAtK implements EvaluationMetric {
return ignoreUnlabeled;
}
/**
* Requests a search window of {@code k} hits for this metric.
*
* @return an optional that always contains {@code k}
*/
@Override
public Optional<Integer> forcedSearchSize() {
return Optional.of(k);
}
/**
* Parses a {@link PrecisionAtK} from the given parser by delegating to {@code PARSER}.
*
* @param parser the parser to read the metric from
* @return the parsed metric
*/
public static PrecisionAtK fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@ -167,6 +187,7 @@ public class PrecisionAtK implements EvaluationMetric {
builder.startObject(NAME);
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
builder.field(K_FIELD.getPreferredName(), this.k);
builder.endObject();
builder.endObject();
return builder;
@ -182,12 +203,13 @@ public class PrecisionAtK implements EvaluationMetric {
}
PrecisionAtK other = (PrecisionAtK) obj;
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
&& Objects.equals(k, other.k)
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
}
@Override
public final int hashCode() {
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled);
// k takes part in the hash as well, matching the fields compared in equals().
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k);
}
static class Breakdown implements MetricDetails {

View File

@ -85,6 +85,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
protected void doExecute(RankEvalRequest request, ActionListener<RankEvalResponse> listener) {
RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
List<String> indices = evaluationSpecification.getIndices();
EvaluationMetric metric = evaluationSpecification.getMetric();
List<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size());
@ -113,6 +114,10 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
}
}
if (metric.forcedSearchSize().isPresent()) {
ratedSearchSource.size(metric.forcedSearchSize().get());
}
ratedRequestsInSearch.add(ratedRequest);
List<String> summaryFields = ratedRequest.getSummaryFields();
if (summaryFields.isEmpty()) {
@ -123,7 +128,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
msearchRequest.add(new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource));
}
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, evaluationSpecification.getMetric(),
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, metric,
ratedRequestsInSearch.toArray(new RatedRequest[ratedRequestsInSearch.size()]), errors));
}

View File

@ -77,7 +77,7 @@ public class PrecisionAtKTests extends ESTestCase {
rated.add(createRatedDoc("test", "2", 2));
rated.add(createRatedDoc("test", "3", 3));
rated.add(createRatedDoc("test", "4", 4));
PrecisionAtK precisionAtN = new PrecisionAtK(2, false);
PrecisionAtK precisionAtN = new PrecisionAtK(2, false, 5);
EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -113,7 +113,7 @@ public class PrecisionAtKTests extends ESTestCase {
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK(1, true);
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
evaluated = prec.evaluate("id", searchHits, rated);
assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -132,7 +132,7 @@ public class PrecisionAtKTests extends ESTestCase {
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK(1, true);
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
evaluated = prec.evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -157,12 +157,15 @@ public class PrecisionAtKTests extends ESTestCase {
}
/**
* A negative relevant rating threshold has to be rejected by the constructor.
*/
public void testInvalidRelevantThreshold() {
PrecisionAtK prez = new PrecisionAtK();
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false));
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false, 10));
}
/**
* The window size k has to be strictly positive: zero and negative values are both rejected.
*/
public void testInvalidK() {
    // Boundary value: the constructor rejects k <= 0, so 0 itself must fail as well.
    expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(1, false, 0));
    expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(1, false, -10));
}
/**
* Creates a randomized metric: a threshold drawn from 0 to 10, a random
* `ignore_unlabeled` flag and a window size k drawn from 1 to 50.
*/
public static PrecisionAtK createTestItem() {
return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean());
return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean(), randomIntBetween(1, 50));
}
public void testXContentRoundtrip() throws IOException {
@ -193,16 +196,28 @@ public class PrecisionAtKTests extends ESTestCase {
}
/**
* Builds a new instance from the original's threshold, `ignore_unlabeled` flag and window size.
* The window size is read back through forcedSearchSize().
*/
private static PrecisionAtK copy(PrecisionAtK original) {
return new PrecisionAtK(original.getRelevantRatingThreshold(), original.getIgnoreUnlabeled());
return new PrecisionAtK(original.getRelevantRatingThreshold(), original.getIgnoreUnlabeled(), original.forcedSearchSize().get());
}
/**
* Builds an instance that differs from the original. Each switch branch changes exactly
* one constructor argument and carries the other two over unchanged.
*/
private static PrecisionAtK mutate(PrecisionAtK original) {
if (randomBoolean()) {
return new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled());
} else {
return new PrecisionAtK(randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
original.getIgnoreUnlabeled());
PrecisionAtK pAtK;
switch (randomIntBetween(0, 2)) {
case 0:
// flip the `ignore_unlabeled` flag
pAtK = new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled(),
original.forcedSearchSize().get());
break;
case 1:
// pick a different relevant rating threshold
pAtK = new PrecisionAtK(randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
original.getIgnoreUnlabeled(), original.forcedSearchSize().get());
break;
case 2:
// increase the window size k by one
pAtK = new PrecisionAtK(original.getRelevantRatingThreshold(),
original.getIgnoreUnlabeled(), original.forcedSearchSize().get() + 1);
break;
default:
throw new IllegalStateException("The test should only allow three parameters mutated");
}
return pAtK;
}
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index) {

View File

@ -20,10 +20,12 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.junit.Before;
@ -68,6 +70,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME);
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
createRelevant("2", "3", "4", "5"), testQuery);
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
@ -79,7 +82,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
specifications.add(berlinRequest);
PrecisionAtK metric = new PrecisionAtK(1, true);
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(indices);
@ -89,7 +92,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
.actionGet();
assertEquals(1.0, response.getEvaluationResult(), Double.MIN_VALUE);
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
assertEquals(2, entrySet.size());
for (Entry<String, EvalQueryQuality> entry : entrySet) {
@ -121,6 +125,18 @@ public class RankEvalRequestIT extends ESIntegTestCase {
}
}
}
// test that a different window size k affects the result
metric = new PrecisionAtK(1, false, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(indices);
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// with k = 3 the expected per-query precisions are 1/3 and 2/3; the overall result is their mean
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
// assert against the computed expectation (as the k = 10 check does) instead of a separate literal
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
}
/**
@ -146,18 +162,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
task.addIndices(indices);
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
.actionGet();
assertEquals(1, response.getFailures().size());
ElasticsearchException[] rootCauses = ElasticsearchException
.guessRootCauses(response.getFailures().get("broken_query"));
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
rootCauses[0].getCause().toString());
try (Client client = client()) {
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(1, response.getFailures().size());
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
rootCauses[0].getCause().toString());
}
}
private static List<RatedDocument> createRelevant(String... docs) {