From d71dc205fad6abfe76bded043103a94034bbc4ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Fri, 29 Jul 2016 14:49:33 +0200 Subject: [PATCH] Moving averaging of partial evaluation results to RankedListQualityMetric For the two current metrics Prec@ and reciprocal rank we currently average the partial results in the transport action. If other metrics later need a different behaviour or want to parametrize this, this operation should be part of the metric itself, so this change moves it there. Also removing one of the two test packages, as the main code is also in one package only. --- .../index/rankeval/PrecisionAtN.java | 4 +--- .../index/rankeval/RankedListQualityMetric.java | 5 +++++ .../index/rankeval/TransportRankEvalAction.java | 15 +++++++++------ .../index/rankeval/PrecisionAtNTests.java | 12 ++++++++++++ .../rankeval}/RankEvalRequestTests.java | 2 +- .../rankeval/RankEvalYamlIT.java} | 6 +++--- .../rankeval}/ReciprocalRankTests.java | 17 +++++++++++++---- 7 files changed, 44 insertions(+), 17 deletions(-) rename modules/rank-eval/src/test/java/org/elasticsearch/{action/quality => index/rankeval}/RankEvalRequestTests.java (99%) rename modules/rank-eval/src/test/java/org/elasticsearch/{action/quality/RankEvalRestIT.java => index/rankeval/RankEvalYamlIT.java} (89%) rename modules/rank-eval/src/test/java/org/elasticsearch/{action/quality => index/rankeval}/ReciprocalRankTests.java (87%) diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java index 11101826ecb..015206692db 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java @@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric { int good = 0; int bad = 0; - Collection unknownDocIds = new ArrayList(); + Collection 
unknownDocIds = new ArrayList<>(); for (int i = 0; (i < n && i < hits.length); i++) { String id = hits[i].getId(); if (relevantDocIds.contains(id)) { @@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric { unknownDocIds.add(id); } } - double precision = (double) good / (good + bad); - return new EvalQueryQuality(precision, unknownDocIds); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java index 6b39afa522f..1bfd6e516f5 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java @@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit; import java.io.IOException; import java.util.List; +import java.util.Vector; /** * Classes implementing this interface provide a means to compute the quality of a result list @@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable { } return rc; } + + double combine(Vector partialResults) { + return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size(); + } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index ee01eb6f76c..ee978b7823b 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService; import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.Vector; /** * Instances of this class execute a collection of search intents 
(read: user supplied query parameters) against a set of @@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction> unknownDocs = new HashMap>(); + Map> unknownDocs = new HashMap<>(); Collection specifications = qualityTask.getSpecifications(); + Vector partialResults = new Vector<>(specifications.size()); for (QuerySpec spec : specifications) { SearchSourceBuilder specRequest = spec.getTestRequest(); String[] indices = new String[spec.getIndices().size()]; @@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction searchResponse = transportSearchAction.execute(templatedRequest); SearchHits hits = searchResponse.actionGet().getHits(); - EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs()); - qualitySum += intentQuality.getQualityLevel(); - unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs()); + EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs()); + partialResults.addElement(queryQuality); + unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs()); } + RankEvalResponse response = new RankEvalResponse(); - // TODO move averaging to actual metric, also add other statistics - RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs); + // TODO add other statistics like micro/macro avg? 
+ RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs); response.setRankEvalResult(result); listener.onResponse(response); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java index c123d5bbf8f..1078c010f56 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java @@ -32,8 +32,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Vector; import java.util.concurrent.ExecutionException; +import static java.util.Collections.emptyList; + public class PrecisionAtNTests extends ESTestCase { public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException { @@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase { PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT); assertEquals(10, precicionAt.getN()); } + + public void testCombine() { + PrecisionAtN metric = new PrecisionAtN(); + Vector partialResults = new Vector<>(3); + partialResults.add(new EvalQueryQuality(0.1, emptyList())); + partialResults.add(new EvalQueryQuality(0.2, emptyList())); + partialResults.add(new EvalQueryQuality(0.6, emptyList())); + assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE); + } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRequestTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java similarity index 99% rename from modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRequestTests.java rename to 
modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java index 79df86f6e56..bb0b993fbdc 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRequestTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java @@ -17,7 +17,7 @@ * under the License. */ -package org.elasticsearch.action.quality; +package org.elasticsearch.index.rankeval; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.rankeval.PrecisionAtN; diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRestIT.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalYamlIT.java similarity index 89% rename from modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRestIT.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalYamlIT.java index dc33d6ac439..4d01d183fc7 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRestIT.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalYamlIT.java @@ -17,7 +17,7 @@ * under the License. 
*/ -package org.elasticsearch.action.quality; +package org.elasticsearch.index.rankeval; import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; @@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException; import java.io.IOException; -public class RankEvalRestIT extends ESClientYamlSuiteTestCase { - public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { +public class RankEvalYamlIT extends ESClientYamlSuiteTestCase { + public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { super(testCandidate); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/ReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java similarity index 87% rename from modules/rank-eval/src/test/java/org/elasticsearch/action/quality/ReciprocalRankTests.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java index d51b8074757..12dd808cff7 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/ReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java @@ -17,13 +17,10 @@ * under the License. 
*/ -package org.elasticsearch.action.quality; +package org.elasticsearch.index.rankeval; import org.elasticsearch.common.text.Text; -import org.elasticsearch.index.rankeval.EvalQueryQuality; import org.elasticsearch.index.rankeval.PrecisionAtN.Rating; -import org.elasticsearch.index.rankeval.RatedDocument; -import org.elasticsearch.index.rankeval.ReciprocalRank; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.internal.InternalSearchHit; import org.elasticsearch.test.ESTestCase; @@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Vector; + +import static java.util.Collections.emptyList; public class ReciprocalRankTests extends ESTestCase { @@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase { assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE); } + public void testCombine() { + ReciprocalRank reciprocalRank = new ReciprocalRank(); + Vector partialResults = new Vector<>(3); + partialResults.add(new EvalQueryQuality(0.5, emptyList())); + partialResults.add(new EvalQueryQuality(1.0, emptyList())); + partialResults.add(new EvalQueryQuality(0.75, emptyList())); + assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE); + } + public void testEvaluationNoRelevantInResults() { ReciprocalRank reciprocalRank = new ReciprocalRank(); SearchHit[] hits = new SearchHit[10];