Make ranking evaluation details accessible for client

Allow high level java rest client to access details of the metric
calculation by making them accessible across packages. Also renaming the
inner `Breakdown` classes of the evaluation metrics to `Detail` to
better communicate their use.
This commit is contained in:
Christoph Büscher 2018-04-10 10:14:52 +02:00
parent c12c2a6cc9
commit 7c56cc2624
8 changed files with 58 additions and 58 deletions

View File

@ -128,7 +128,7 @@ public class MeanReciprocalRank implements EvaluationMetric {
double reciprocalRank = (firstRelevant == -1) ? 0 : 1.0d / firstRelevant;
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, reciprocalRank);
evalQueryQuality.setMetricDetails(new Breakdown(firstRelevant));
evalQueryQuality.setMetricDetails(new Detail(firstRelevant));
evalQueryQuality.addHitsAndRatings(ratedHits);
return evalQueryQuality;
}
@ -181,16 +181,16 @@ public class MeanReciprocalRank implements EvaluationMetric {
return Objects.hash(relevantRatingThreshhold, k);
}
static class Breakdown implements MetricDetail {
public static final class Detail implements MetricDetail {
private final int firstRelevantRank;
private static ParseField FIRST_RELEVANT_RANK_FIELD = new ParseField("first_relevant");
Breakdown(int firstRelevantRank) {
Detail(int firstRelevantRank) {
this.firstRelevantRank = firstRelevantRank;
}
Breakdown(StreamInput in) throws IOException {
Detail(StreamInput in) throws IOException {
this.firstRelevantRank = in.readVInt();
}
@ -206,15 +206,15 @@ public class MeanReciprocalRank implements EvaluationMetric {
return builder.field(FIRST_RELEVANT_RANK_FIELD.getPreferredName(), firstRelevantRank);
}
private static final ConstructingObjectParser<Breakdown, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Breakdown((Integer) args[0]);
private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Detail((Integer) args[0]);
});
static {
PARSER.declareInt(constructorArg(), FIRST_RELEVANT_RANK_FIELD);
}
public static Breakdown fromXContent(XContentParser parser) {
public static Detail fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@ -232,24 +232,24 @@ public class MeanReciprocalRank implements EvaluationMetric {
* the ranking of the first relevant document, or -1 if no relevant document was
* found
*/
int getFirstRelevantRank() {
public int getFirstRelevantRank() {
return firstRelevantRank;
}
@Override
public final boolean equals(Object obj) {
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
MeanReciprocalRank.Breakdown other = (MeanReciprocalRank.Breakdown) obj;
MeanReciprocalRank.Detail other = (MeanReciprocalRank.Detail) obj;
return Objects.equals(firstRelevantRank, other.firstRelevantRank);
}
@Override
public final int hashCode() {
public int hashCode() {
return Objects.hash(firstRelevantRank);
}
}

View File

@ -181,7 +181,7 @@ public class PrecisionAtK implements EvaluationMetric {
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision);
evalQueryQuality.setMetricDetails(
new PrecisionAtK.Breakdown(truePositives, truePositives + falsePositives));
new PrecisionAtK.Detail(truePositives, truePositives + falsePositives));
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
return evalQueryQuality;
}
@ -217,19 +217,19 @@ public class PrecisionAtK implements EvaluationMetric {
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k);
}
static class Breakdown implements MetricDetail {
public static final class Detail implements MetricDetail {
private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");
private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
private int relevantRetrieved;
private int retrieved;
Breakdown(int relevantRetrieved, int retrieved) {
Detail(int relevantRetrieved, int retrieved) {
this.relevantRetrieved = relevantRetrieved;
this.retrieved = retrieved;
}
Breakdown(StreamInput in) throws IOException {
Detail(StreamInput in) throws IOException {
this.relevantRetrieved = in.readVInt();
this.retrieved = in.readVInt();
}
@ -242,8 +242,8 @@ public class PrecisionAtK implements EvaluationMetric {
return builder;
}
private static final ConstructingObjectParser<Breakdown, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Breakdown((Integer) args[0], (Integer) args[1]);
private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Detail((Integer) args[0], (Integer) args[1]);
});
static {
@ -251,7 +251,7 @@ public class PrecisionAtK implements EvaluationMetric {
PARSER.declareInt(constructorArg(), DOCS_RETRIEVED_FIELD);
}
public static Breakdown fromXContent(XContentParser parser) {
public static Detail fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@ -266,29 +266,29 @@ public class PrecisionAtK implements EvaluationMetric {
return NAME;
}
int getRelevantRetrieved() {
public int getRelevantRetrieved() {
return relevantRetrieved;
}
int getRetrieved() {
public int getRetrieved() {
return retrieved;
}
@Override
public final boolean equals(Object obj) {
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PrecisionAtK.Breakdown other = (PrecisionAtK.Breakdown) obj;
PrecisionAtK.Detail other = (PrecisionAtK.Detail) obj;
return Objects.equals(relevantRetrieved, other.relevantRetrieved)
&& Objects.equals(retrieved, other.retrieved);
}
@Override
public final int hashCode() {
public int hashCode() {
return Objects.hash(relevantRetrieved, retrieved);
}
}

View File

@ -38,9 +38,9 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME),
PrecisionAtK.Breakdown::fromXContent));
PrecisionAtK.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank.Breakdown::fromXContent));
MeanReciprocalRank.Detail::fromXContent));
return namedXContent;
}
}

View File

@ -60,9 +60,9 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Breakdown::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
namedWriteables
.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
return namedWriteables;
}

View File

@ -69,9 +69,9 @@ public class EvalQueryQualityTests extends ESTestCase {
randomDoubleBetween(0.0, 1.0, true));
if (randomBoolean()) {
if (randomBoolean()) {
evalQueryQuality.setMetricDetails(new PrecisionAtK.Breakdown(randomIntBetween(0, 1000), randomIntBetween(0, 1000)));
evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000)));
} else {
evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Breakdown(randomIntBetween(0, 1000)));
evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000)));
}
}
evalQueryQuality.addHitsAndRatings(ratedHits);
@ -137,7 +137,7 @@ public class EvalQueryQualityTests extends ESTestCase {
break;
case 2:
if (metricDetails == null) {
metricDetails = new PrecisionAtK.Breakdown(1, 5);
metricDetails = new PrecisionAtK.Detail(1, 5);
} else {
metricDetails = null;
}

View File

@ -96,7 +96,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
int rankAtFirstRelevant = relevantAt + 1;
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, ratedDocs);
assertEquals(1.0 / rankAtFirstRelevant, evaluation.getQualityLevel(), Double.MIN_VALUE);
assertEquals(rankAtFirstRelevant, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
assertEquals(rankAtFirstRelevant, ((MeanReciprocalRank.Detail) evaluation.getMetricDetails()).getFirstRelevantRank());
// check that if we have fewer search hits than relevant doc position,
// we don't find any result and get 0.0 quality level
@ -121,7 +121,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, ratedDocs);
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
assertEquals(relevantAt + 1, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
assertEquals(relevantAt + 1, ((MeanReciprocalRank.Detail) evaluation.getMetricDetails()).getFirstRelevantRank());
}
/**
@ -141,7 +141,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2, 10);
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
assertEquals(3, ((MeanReciprocalRank.Detail) evaluation.getMetricDetails()).getFirstRelevantRank());
}
public void testCombine() {
@ -165,7 +165,7 @@ public class MeanReciprocalRankTests extends ESTestCase {
SearchHit[] hits = new SearchHit[0];
EvalQueryQuality evaluated = (new MeanReciprocalRank()).evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(-1, ((MeanReciprocalRank.Breakdown) evaluated.getMetricDetails()).getFirstRelevantRank());
assertEquals(-1, ((MeanReciprocalRank.Detail) evaluated.getMetricDetails()).getFirstRelevantRank());
}
public void testXContentRoundtrip() throws IOException {

View File

@ -54,8 +54,8 @@ public class PrecisionAtKTests extends ESTestCase {
rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1));
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals(1, evaluated.getQualityLevel(), 0.00001);
assertEquals(1, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testPrecisionAtFiveIgnoreOneResult() {
@ -67,8 +67,8 @@ public class PrecisionAtKTests extends ESTestCase {
rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING_0));
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 4 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(4, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(4, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
/**
@ -86,8 +86,8 @@ public class PrecisionAtKTests extends ESTestCase {
PrecisionAtK precisionAtN = new PrecisionAtK(2, false, 5);
EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testPrecisionAtFiveCorrectIndex() {
@ -100,8 +100,8 @@ public class PrecisionAtKTests extends ESTestCase {
// the following search hits contain only the last three documents
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testIgnoreUnlabeled() {
@ -115,15 +115,15 @@ public class PrecisionAtKTests extends ESTestCase {
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", searchHits, rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
evaluated = prec.evaluate("id", searchHits, rated);
assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testNoRatedDocs() throws Exception {
@ -134,23 +134,23 @@ public class PrecisionAtKTests extends ESTestCase {
}
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK(1, true, 10);
evaluated = prec.evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testNoResults() throws Exception {
SearchHit[] hits = new SearchHit[0];
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved());
}
public void testParseFromXContent() throws IOException {

View File

@ -25,7 +25,7 @@ import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.rankeval.PrecisionAtK.Breakdown;
import org.elasticsearch.index.rankeval.PrecisionAtK.Detail;
import org.elasticsearch.indices.IndexClosedException;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
@ -271,7 +271,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
request.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, request).actionGet();
Breakdown details = (PrecisionAtK.Breakdown) response.getPartialResults().get("amsterdam_query").getMetricDetails();
Detail details = (PrecisionAtK.Detail) response.getPartialResults().get("amsterdam_query").getMetricDetails();
assertEquals(7, details.getRetrieved());
assertEquals(6, details.getRelevantRetrieved());
@ -280,7 +280,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
request.indicesOptions(IndicesOptions.fromParameters(null, "true", null, SearchRequest.DEFAULT_INDICES_OPTIONS));
response = client().execute(RankEvalAction.INSTANCE, request).actionGet();
details = (PrecisionAtK.Breakdown) response.getPartialResults().get("amsterdam_query").getMetricDetails();
details = (PrecisionAtK.Detail) response.getPartialResults().get("amsterdam_query").getMetricDetails();
assertEquals(6, details.getRetrieved());
assertEquals(5, details.getRelevantRetrieved());
@ -295,12 +295,12 @@ public class RankEvalRequestIT extends ESIntegTestCase {
request = new RankEvalRequest(task, new String[] { "tes*" });
request.indicesOptions(IndicesOptions.fromParameters("none", null, null, SearchRequest.DEFAULT_INDICES_OPTIONS));
response = client().execute(RankEvalAction.INSTANCE, request).actionGet();
details = (PrecisionAtK.Breakdown) response.getPartialResults().get("amsterdam_query").getMetricDetails();
details = (PrecisionAtK.Detail) response.getPartialResults().get("amsterdam_query").getMetricDetails();
assertEquals(0, details.getRetrieved());
request.indicesOptions(IndicesOptions.fromParameters("open", null, null, SearchRequest.DEFAULT_INDICES_OPTIONS));
response = client().execute(RankEvalAction.INSTANCE, request).actionGet();
details = (PrecisionAtK.Breakdown) response.getPartialResults().get("amsterdam_query").getMetricDetails();
details = (PrecisionAtK.Detail) response.getPartialResults().get("amsterdam_query").getMetricDetails();
assertEquals(6, details.getRetrieved());
assertEquals(5, details.getRelevantRetrieved());
@ -313,7 +313,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
request = new RankEvalRequest(task, new String[] { "bad*" });
request.indicesOptions(IndicesOptions.fromParameters(null, null, "true", SearchRequest.DEFAULT_INDICES_OPTIONS));
response = client().execute(RankEvalAction.INSTANCE, request).actionGet();
details = (PrecisionAtK.Breakdown) response.getPartialResults().get("amsterdam_query").getMetricDetails();
details = (PrecisionAtK.Detail) response.getPartialResults().get("amsterdam_query").getMetricDetails();
assertEquals(0, details.getRetrieved());
request.indicesOptions(IndicesOptions.fromParameters(null, null, "false", SearchRequest.DEFAULT_INDICES_OPTIONS));