Moving averaging of partial evaluation results to RankedListQualityMetric

For the two current metrics, Prec@ and reciprocal rank, we currently average the
partial results in the transport action. If other metrics later need a different
behaviour or want to parametrize this, this operation should be part of the
metric itself, so this change moves it there. Also removing one of the two test
packages, since the main code is also in one package only.
This commit is contained in:
Christoph Büscher 2016-07-29 14:49:33 +02:00
parent 0fb7dd9054
commit d71dc205fa
7 changed files with 44 additions and 17 deletions

View File

@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
int good = 0; int good = 0;
int bad = 0; int bad = 0;
Collection<String> unknownDocIds = new ArrayList<String>(); Collection<String> unknownDocIds = new ArrayList<>();
for (int i = 0; (i < n && i < hits.length); i++) { for (int i = 0; (i < n && i < hits.length); i++) {
String id = hits[i].getId(); String id = hits[i].getId();
if (relevantDocIds.contains(id)) { if (relevantDocIds.contains(id)) {
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
unknownDocIds.add(id); unknownDocIds.add(id);
} }
} }
double precision = (double) good / (good + bad); double precision = (double) good / (good + bad);
return new EvalQueryQuality(precision, unknownDocIds); return new EvalQueryQuality(precision, unknownDocIds);
} }

View File

@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Vector;
/** /**
* Classes implementing this interface provide a means to compute the quality of a result list * Classes implementing this interface provide a means to compute the quality of a result list
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
} }
return rc; return rc;
} }
/**
 * Combines the partial evaluation results into one overall quality value by
 * taking the arithmetic mean of their quality levels.
 *
 * @param partialResults the partial results whose quality levels are averaged
 * @return the sum of all quality levels divided by the number of partial results;
 *         for an empty input this is 0.0 / 0, i.e. {@code NaN}
 */
double combine(Vector<EvalQueryQuality> partialResults) {
    final double qualityTotal = partialResults.stream()
            .mapToDouble(partial -> partial.getQualityLevel())
            .sum();
    final int resultCount = partialResults.size();
    return qualityTotal / resultCount;
}
} }

View File

@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Vector;
/** /**
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of * Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
RankedListQualityMetric metric = qualityTask.getEvaluator(); RankedListQualityMetric metric = qualityTask.getEvaluator();
double qualitySum = 0; double qualitySum = 0;
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>(); Map<String, Collection<String>> unknownDocs = new HashMap<>();
Collection<QuerySpec> specifications = qualityTask.getSpecifications(); Collection<QuerySpec> specifications = qualityTask.getSpecifications();
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
for (QuerySpec spec : specifications) { for (QuerySpec spec : specifications) {
SearchSourceBuilder specRequest = spec.getTestRequest(); SearchSourceBuilder specRequest = spec.getTestRequest();
String[] indices = new String[spec.getIndices().size()]; String[] indices = new String[spec.getIndices().size()];
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest); ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
SearchHits hits = searchResponse.actionGet().getHits(); SearchHits hits = searchResponse.actionGet().getHits();
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs()); EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
qualitySum += intentQuality.getQualityLevel(); partialResults.addElement(queryQuality);
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs()); unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
} }
RankEvalResponse response = new RankEvalResponse(); RankEvalResponse response = new RankEvalResponse();
// TODO move averaging to actual metric, also add other statistics // TODO add other statistics like micro/macro avg?
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs); RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
response.setRankEvalResult(result); response.setRankEvalResult(result);
listener.onResponse(response); listener.onResponse(response);
} }

View File

@ -32,8 +32,11 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Vector;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import static java.util.Collections.emptyList;
public class PrecisionAtNTests extends ESTestCase { public class PrecisionAtNTests extends ESTestCase {
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException { public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT); PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
assertEquals(10, precicionAt.getN()); assertEquals(10, precicionAt.getN());
} }
/**
 * Checks that combining partial results with quality levels 0.1, 0.2 and 0.6
 * yields their mean, 0.3.
 */
public void testCombine() {
    PrecisionAtN precisionAtN = new PrecisionAtN();
    Vector<EvalQueryQuality> queryQualities = new Vector<>(3);
    for (double level : new double[] { 0.1, 0.2, 0.6 }) {
        queryQualities.add(new EvalQueryQuality(level, emptyList()));
    }
    double combined = precisionAtN.combine(queryQualities);
    assertEquals(0.3, combined, Double.MIN_VALUE);
}
} }

View File

@ -17,7 +17,7 @@
* under the License. * under the License.
*/ */
package org.elasticsearch.action.quality; package org.elasticsearch.index.rankeval;
import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.rankeval.PrecisionAtN; import org.elasticsearch.index.rankeval.PrecisionAtN;

View File

@ -17,7 +17,7 @@
* under the License. * under the License.
*/ */
package org.elasticsearch.action.quality; package org.elasticsearch.index.rankeval;
import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
import java.io.IOException; import java.io.IOException;
public class RankEvalRestIT extends ESClientYamlSuiteTestCase { public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate); super(testCandidate);
} }

View File

@ -17,13 +17,10 @@
* under the License. * under the License.
*/ */
package org.elasticsearch.action.quality; package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.rankeval.EvalQueryQuality;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating; import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.index.rankeval.RatedDocument;
import org.elasticsearch.index.rankeval.ReciprocalRank;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.internal.InternalSearchHit; import org.elasticsearch.search.internal.InternalSearchHit;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Vector;
import static java.util.Collections.emptyList;
public class ReciprocalRankTests extends ESTestCase { public class ReciprocalRankTests extends ESTestCase {
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE); assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
} }
/**
 * Checks that combining partial results with quality levels 0.5, 1.0 and 0.75
 * yields their mean, 0.75.
 */
public void testCombine() {
    ReciprocalRank metric = new ReciprocalRank();
    Vector<EvalQueryQuality> queryQualities = new Vector<>(3);
    for (double level : new double[] { 0.5, 1.0, 0.75 }) {
        queryQualities.add(new EvalQueryQuality(level, emptyList()));
    }
    double combined = metric.combine(queryQualities);
    assertEquals(0.75, combined, Double.MIN_VALUE);
}
public void testEvaluationNoRelevantInResults() { public void testEvaluationNoRelevantInResults() {
ReciprocalRank reciprocalRank = new ReciprocalRank(); ReciprocalRank reciprocalRank = new ReciprocalRank();
SearchHit[] hits = new SearchHit[10]; SearchHit[] hits = new SearchHit[10];