Add k parameter to PrecisionAtK metric (#27569)
This commit is contained in:
parent
1352b7c6ea
commit
7bfb273763
|
@ -89,4 +89,13 @@ public interface EvaluationMetric extends ToXContent, NamedWriteable {
|
|||
default double combine(Collection<EvalQueryQuality> partialResults) {
|
||||
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Metrics can define a size of the search hits windows they want to retrieve by overwriting
|
||||
* this method. The default implementation returns an empty optional.
|
||||
* @return the number of search hits this metrics requests
|
||||
*/
|
||||
default Optional<Integer> forcedSearchSize() {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRati
|
|||
* relevant_rating_threshold` parameter.<br>
|
||||
* The `ignore_unlabeled` parameter (default to false) controls if unrated
|
||||
* documents should be ignored.
|
||||
* The `k` parameter (defaults to 10) controls the search window size.
|
||||
*/
|
||||
public class PrecisionAtK implements EvaluationMetric {
|
||||
|
||||
|
@ -52,9 +53,13 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
|
||||
private static final ParseField K_FIELD = new ParseField("k");
|
||||
|
||||
private static final int DEFAULT_K = 10;
|
||||
|
||||
private final boolean ignoreUnlabeled;
|
||||
private final int relevantRatingThreshhold;
|
||||
private final int k;
|
||||
|
||||
/**
|
||||
* Metric implementing Precision@K.
|
||||
|
@ -65,42 +70,52 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
* Set to 'true', unlabeled documents are ignored and neither count
|
||||
* as true or false positives. Set to 'false', they are treated as
|
||||
* false positives.
|
||||
* @param k
|
||||
* controls the window size for the search results the metric takes into account
|
||||
*/
|
||||
public PrecisionAtK(int threshold, boolean ignoreUnlabeled) {
|
||||
public PrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) {
|
||||
if (threshold < 0) {
|
||||
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
if (k <= 0) {
|
||||
throw new IllegalArgumentException("Window size k must be positive.");
|
||||
}
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
this.ignoreUnlabeled = ignoreUnlabeled;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
public PrecisionAtK() {
|
||||
this(1, false);
|
||||
this(1, false, DEFAULT_K);
|
||||
}
|
||||
|
||||
|
||||
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
args -> {
|
||||
Integer threshHold = (Integer) args[0];
|
||||
Boolean ignoreUnlabeled = (Boolean) args[1];
|
||||
Integer k = (Integer) args[2];
|
||||
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
|
||||
ignoreUnlabeled == null ? false : ignoreUnlabeled);
|
||||
ignoreUnlabeled == null ? false : ignoreUnlabeled,
|
||||
k == null ? 10 : k);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
|
||||
}
|
||||
|
||||
PrecisionAtK(StreamInput in) throws IOException {
|
||||
relevantRatingThreshhold = in.readVInt();
|
||||
ignoreUnlabeled = in.readBoolean();
|
||||
k = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(relevantRatingThreshhold);
|
||||
out.writeBoolean(ignoreUnlabeled);
|
||||
out.writeVInt(k);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -123,6 +138,11 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
return ignoreUnlabeled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<Integer> forcedSearchSize() {
|
||||
return Optional.of(k);
|
||||
}
|
||||
|
||||
public static PrecisionAtK fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
@ -167,6 +187,7 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
builder.startObject(NAME);
|
||||
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
|
||||
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
|
||||
builder.field(K_FIELD.getPreferredName(), this.k);
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -182,12 +203,13 @@ public class PrecisionAtK implements EvaluationMetric {
|
|||
}
|
||||
PrecisionAtK other = (PrecisionAtK) obj;
|
||||
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
|
||||
&& Objects.equals(k, other.k)
|
||||
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled);
|
||||
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k);
|
||||
}
|
||||
|
||||
static class Breakdown implements MetricDetails {
|
||||
|
|
|
@ -85,6 +85,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
protected void doExecute(RankEvalRequest request, ActionListener<RankEvalResponse> listener) {
|
||||
RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
|
||||
List<String> indices = evaluationSpecification.getIndices();
|
||||
EvaluationMetric metric = evaluationSpecification.getMetric();
|
||||
|
||||
List<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
|
||||
Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size());
|
||||
|
@ -113,6 +114,10 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
}
|
||||
}
|
||||
|
||||
if (metric.forcedSearchSize().isPresent()) {
|
||||
ratedSearchSource.size(metric.forcedSearchSize().get());
|
||||
}
|
||||
|
||||
ratedRequestsInSearch.add(ratedRequest);
|
||||
List<String> summaryFields = ratedRequest.getSummaryFields();
|
||||
if (summaryFields.isEmpty()) {
|
||||
|
@ -123,7 +128,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
msearchRequest.add(new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource));
|
||||
}
|
||||
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
|
||||
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, evaluationSpecification.getMetric(),
|
||||
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, metric,
|
||||
ratedRequestsInSearch.toArray(new RatedRequest[ratedRequestsInSearch.size()]), errors));
|
||||
}
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ public class PrecisionAtKTests extends ESTestCase {
|
|||
rated.add(createRatedDoc("test", "2", 2));
|
||||
rated.add(createRatedDoc("test", "3", 3));
|
||||
rated.add(createRatedDoc("test", "4", 4));
|
||||
PrecisionAtK precisionAtN = new PrecisionAtK(2, false);
|
||||
PrecisionAtK precisionAtN = new PrecisionAtK(2, false, 5);
|
||||
EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
|
@ -113,7 +113,7 @@ public class PrecisionAtKTests extends ESTestCase {
|
|||
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true);
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
|
||||
evaluated = prec.evaluate("id", searchHits, rated);
|
||||
assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
|
@ -132,7 +132,7 @@ public class PrecisionAtKTests extends ESTestCase {
|
|||
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true);
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
|
||||
evaluated = prec.evaluate("id", hits, Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
|
@ -157,12 +157,15 @@ public class PrecisionAtKTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
PrecisionAtK prez = new PrecisionAtK();
|
||||
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false));
|
||||
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false, 10));
|
||||
}
|
||||
|
||||
public void testInvalidK() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(1, false, -10));
|
||||
}
|
||||
|
||||
public static PrecisionAtK createTestItem() {
|
||||
return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean());
|
||||
return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean(), randomIntBetween(1, 50));
|
||||
}
|
||||
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
|
@ -193,16 +196,28 @@ public class PrecisionAtKTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private static PrecisionAtK copy(PrecisionAtK original) {
|
||||
return new PrecisionAtK(original.getRelevantRatingThreshold(), original.getIgnoreUnlabeled());
|
||||
return new PrecisionAtK(original.getRelevantRatingThreshold(), original.getIgnoreUnlabeled(), original.forcedSearchSize().get());
|
||||
}
|
||||
|
||||
private static PrecisionAtK mutate(PrecisionAtK original) {
|
||||
if (randomBoolean()) {
|
||||
return new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled());
|
||||
} else {
|
||||
return new PrecisionAtK(randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
|
||||
original.getIgnoreUnlabeled());
|
||||
PrecisionAtK pAtK;
|
||||
switch (randomIntBetween(0, 2)) {
|
||||
case 0:
|
||||
pAtK = new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled(),
|
||||
original.forcedSearchSize().get());
|
||||
break;
|
||||
case 1:
|
||||
pAtK = new PrecisionAtK(randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
|
||||
original.getIgnoreUnlabeled(), original.forcedSearchSize().get());
|
||||
break;
|
||||
case 2:
|
||||
pAtK = new PrecisionAtK(original.getRelevantRatingThreshold(),
|
||||
original.getIgnoreUnlabeled(), original.forcedSearchSize().get() + 1);
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("The test should only allow three parameters mutated");
|
||||
}
|
||||
return pAtK;
|
||||
}
|
||||
|
||||
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index) {
|
||||
|
|
|
@ -20,10 +20,12 @@
|
|||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -68,6 +70,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
testQuery.sort(FieldSortBuilder.DOC_FIELD_NAME);
|
||||
RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query",
|
||||
createRelevant("2", "3", "4", "5"), testQuery);
|
||||
amsterdamRequest.addSummaryFields(Arrays.asList(new String[] { "text", "title" }));
|
||||
|
@ -79,7 +82,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
|
||||
specifications.add(berlinRequest);
|
||||
|
||||
PrecisionAtK metric = new PrecisionAtK(1, true);
|
||||
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(indices);
|
||||
|
||||
|
@ -89,7 +92,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
|
||||
.actionGet();
|
||||
assertEquals(1.0, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
|
||||
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
|
||||
assertEquals(2, entrySet.size());
|
||||
for (Entry<String, EvalQueryQuality> entry : entrySet) {
|
||||
|
@ -121,6 +125,18 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// test that a different window size k affects the result
|
||||
metric = new PrecisionAtK(1, false, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
task.addIndices(indices);
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
expectedPrecision = (1.0 / 3.0 + 2.0 / 3.0) / 2.0;
|
||||
assertEquals(0.5, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -146,18 +162,16 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
|
||||
task.addIndices(indices);
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
|
||||
RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
|
||||
.actionGet();
|
||||
assertEquals(1, response.getFailures().size());
|
||||
ElasticsearchException[] rootCauses = ElasticsearchException
|
||||
.guessRootCauses(response.getFailures().get("broken_query"));
|
||||
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
|
||||
rootCauses[0].getCause().toString());
|
||||
try (Client client = client()) {
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client, RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client.execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(1, response.getFailures().size());
|
||||
ElasticsearchException[] rootCauses = ElasticsearchException.guessRootCauses(response.getFailures().get("broken_query"));
|
||||
assertEquals("java.lang.NumberFormatException: For input string: \"noStringOnNumericFields\"",
|
||||
rootCauses[0].getCause().toString());
|
||||
}
|
||||
}
|
||||
|
||||
private static List<RatedDocument> createRelevant(String... docs) {
|
||||
|
|
Loading…
Reference in New Issue