diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java index 11101826ecb..015206692db 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtN.java @@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric { int good = 0; int bad = 0; - Collection unknownDocIds = new ArrayList(); + Collection unknownDocIds = new ArrayList<>(); for (int i = 0; (i < n && i < hits.length); i++) { String id = hits[i].getId(); if (relevantDocIds.contains(id)) { @@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric { unknownDocIds.add(id); } } - double precision = (double) good / (good + bad); - return new EvalQueryQuality(precision, unknownDocIds); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java index 6b39afa522f..1bfd6e516f5 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankedListQualityMetric.java @@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit; import java.io.IOException; import java.util.List; +import java.util.Vector; /** * Classes implementing this interface provide a means to compute the quality of a result list @@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable { } return rc; } + + double combine(Vector partialResults) { + return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size(); + } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index ee01eb6f76c..ee978b7823b 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService; import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.Vector; /** * Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of @@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction> unknownDocs = new HashMap>(); + Map> unknownDocs = new HashMap<>(); Collection specifications = qualityTask.getSpecifications(); + Vector partialResults = new Vector<>(specifications.size()); for (QuerySpec spec : specifications) { SearchSourceBuilder specRequest = spec.getTestRequest(); String[] indices = new String[spec.getIndices().size()]; @@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction searchResponse = transportSearchAction.execute(templatedRequest); SearchHits hits = searchResponse.actionGet().getHits(); - EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs()); - qualitySum += intentQuality.getQualityLevel(); - unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs()); + EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs()); + partialResults.addElement(queryQuality); + unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs()); } + RankEvalResponse response = new RankEvalResponse(); - // TODO move averaging to actual metric, also add other statistics - RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs); + // TODO add other statistics like micro/macro avg? + RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs); response.setRankEvalResult(result); listener.onResponse(response); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java index c123d5bbf8f..1078c010f56 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtNTests.java @@ -32,8 +32,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Vector; import java.util.concurrent.ExecutionException; +import static java.util.Collections.emptyList; + public class PrecisionAtNTests extends ESTestCase { public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException { @@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase { PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT); assertEquals(10, precicionAt.getN()); } + + public void testCombine() { + PrecisionAtN metric = new PrecisionAtN(); + Vector partialResults = new Vector<>(3); + partialResults.add(new EvalQueryQuality(0.1, emptyList())); + partialResults.add(new EvalQueryQuality(0.2, emptyList())); + partialResults.add(new EvalQueryQuality(0.6, emptyList())); + assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE); + } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRequestTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java similarity index 99% rename from modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRequestTests.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java index 79df86f6e56..bb0b993fbdc 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRequestTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java @@ -17,7 +17,7 @@ * under the License. */ -package org.elasticsearch.action.quality; +package org.elasticsearch.index.rankeval; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.rankeval.PrecisionAtN; diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRestIT.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalYamlIT.java similarity index 89% rename from modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRestIT.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalYamlIT.java index dc33d6ac439..4d01d183fc7 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/RankEvalRestIT.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalYamlIT.java @@ -17,7 +17,7 @@ * under the License. */ -package org.elasticsearch.action.quality; +package org.elasticsearch.index.rankeval; import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; @@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException; import java.io.IOException; -public class RankEvalRestIT extends ESClientYamlSuiteTestCase { - public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { +public class RankEvalYamlIT extends ESClientYamlSuiteTestCase { + public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { super(testCandidate); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/ReciprocalRankTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java similarity index 87% rename from modules/rank-eval/src/test/java/org/elasticsearch/action/quality/ReciprocalRankTests.java rename to modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java index d51b8074757..12dd808cff7 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/action/quality/ReciprocalRankTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/ReciprocalRankTests.java @@ -17,13 +17,10 @@ * under the License. */ -package org.elasticsearch.action.quality; +package org.elasticsearch.index.rankeval; import org.elasticsearch.common.text.Text; -import org.elasticsearch.index.rankeval.EvalQueryQuality; import org.elasticsearch.index.rankeval.PrecisionAtN.Rating; -import org.elasticsearch.index.rankeval.RatedDocument; -import org.elasticsearch.index.rankeval.ReciprocalRank; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.internal.InternalSearchHit; import org.elasticsearch.test.ESTestCase; @@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Vector; + +import static java.util.Collections.emptyList; public class ReciprocalRankTests extends ESTestCase { @@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase { assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE); } + public void testCombine() { + ReciprocalRank reciprocalRank = new ReciprocalRank(); + Vector partialResults = new Vector<>(3); + partialResults.add(new EvalQueryQuality(0.5, emptyList())); + partialResults.add(new EvalQueryQuality(1.0, emptyList())); + partialResults.add(new EvalQueryQuality(0.75, emptyList())); + assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE); + } + public void testEvaluationNoRelevantInResults() { ReciprocalRank reciprocalRank = new ReciprocalRank(); SearchHit[] hits = new SearchHit[10];