Moving averaging of partial evaluation results to RankedListQualityMetric

For the two current metrics, Prec@ and reciprocal rank, we currently average the
partial results in the transport action. If other metrics later need different
behaviour or want to parametrize this, the operation should be part of the
metric itself, so this change moves it there. Also removing one of the two test
packages, since the main code is in one package only.
This commit is contained in:
Christoph Büscher 2016-07-29 14:49:33 +02:00
parent 0fb7dd9054
commit d71dc205fa
7 changed files with 44 additions and 17 deletions

View File

@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
int good = 0;
int bad = 0;
Collection<String> unknownDocIds = new ArrayList<String>();
Collection<String> unknownDocIds = new ArrayList<>();
for (int i = 0; (i < n && i < hits.length); i++) {
String id = hits[i].getId();
if (relevantDocIds.contains(id)) {
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
unknownDocIds.add(id);
}
}
double precision = (double) good / (good + bad);
return new EvalQueryQuality(precision, unknownDocIds);
}

View File

@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.util.List;
import java.util.Vector;
/**
* Classes implementing this interface provide a means to compute the quality of a result list
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
}
return rc;
}
/**
 * Combines the partial evaluation results into a single quality value by taking the
 * arithmetic mean of their quality levels.
 *
 * @param partialResults the partial results whose quality levels are averaged
 * @return the sum of all quality levels divided by the number of partial results
 */
double combine(Vector<EvalQueryQuality> partialResults) {
    double totalQuality = partialResults.stream()
            .mapToDouble(EvalQueryQuality::getQualityLevel)
            .sum();
    int resultCount = partialResults.size();
    return totalQuality / resultCount;
}
}

View File

@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
/**
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
RankedListQualityMetric metric = qualityTask.getEvaluator();
double qualitySum = 0;
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>();
Map<String, Collection<String>> unknownDocs = new HashMap<>();
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
for (QuerySpec spec : specifications) {
SearchSourceBuilder specRequest = spec.getTestRequest();
String[] indices = new String[spec.getIndices().size()];
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
SearchHits hits = searchResponse.actionGet().getHits();
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
qualitySum += intentQuality.getQualityLevel();
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs());
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
partialResults.addElement(queryQuality);
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
}
RankEvalResponse response = new RankEvalResponse();
// TODO move averaging to actual metric, also add other statistics
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs);
// TODO add other statistics like micro/macro avg?
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
response.setRankEvalResult(result);
listener.onResponse(response);
}

View File

@ -32,8 +32,11 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.ExecutionException;
import static java.util.Collections.emptyList;
public class PrecisionAtNTests extends ESTestCase {
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
assertEquals(10, precicionAt.getN());
}
/**
 * Checks that {@code combine} returns the mean of the quality levels of the partial results.
 */
public void testCombine() {
    PrecisionAtN metric = new PrecisionAtN();
    Vector<EvalQueryQuality> partialResults = new Vector<>(3);
    partialResults.add(new EvalQueryQuality(0.1, emptyList()));
    partialResults.add(new EvalQueryQuality(0.2, emptyList()));
    partialResults.add(new EvalQueryQuality(0.6, emptyList()));
    // 0.1, 0.2 and 0.6 are not exactly representable as doubles, so compare with a real
    // tolerance. Double.MIN_VALUE (~4.9e-324) as delta would effectively demand exact
    // equality and tie the test to the precise summation order used by combine().
    assertEquals(0.3, metric.combine(partialResults), 1e-10);
}
}

View File

@ -17,7 +17,7 @@
* under the License.
*/
package org.elasticsearch.action.quality;
package org.elasticsearch.index.rankeval;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.rankeval.PrecisionAtN;

View File

@ -17,7 +17,7 @@
* under the License.
*/
package org.elasticsearch.action.quality;
package org.elasticsearch.index.rankeval;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
import java.io.IOException;
public class RankEvalRestIT extends ESClientYamlSuiteTestCase {
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}

View File

@ -17,13 +17,10 @@
* under the License.
*/
package org.elasticsearch.action.quality;
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.rankeval.EvalQueryQuality;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.index.rankeval.RatedDocument;
import org.elasticsearch.index.rankeval.ReciprocalRank;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.internal.InternalSearchHit;
import org.elasticsearch.test.ESTestCase;
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
import static java.util.Collections.emptyList;
public class ReciprocalRankTests extends ESTestCase {
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
}
/**
 * Checks that combining three partial results yields the mean of their quality levels.
 */
public void testCombine() {
    Vector<EvalQueryQuality> perQueryResults = new Vector<>(3);
    for (double qualityLevel : new double[] { 0.5, 1.0, 0.75 }) {
        perQueryResults.add(new EvalQueryQuality(qualityLevel, emptyList()));
    }
    ReciprocalRank metric = new ReciprocalRank();
    double combined = metric.combine(perQueryResults);
    assertEquals(0.75, combined, Double.MIN_VALUE);
}
public void testEvaluationNoRelevantInResults() {
ReciprocalRank reciprocalRank = new ReciprocalRank();
SearchHit[] hits = new SearchHit[10];