Merge pull request #19686 from cbuescher/rankMetric-combine

Moving averaging of partial evaluation results to RankedListQualityMetric
This commit is contained in:
Christoph Büscher 2016-08-04 10:39:46 +02:00 committed by GitHub
commit acba915340
7 changed files with 44 additions and 17 deletions

View File

@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
int good = 0;
int bad = 0;
Collection<String> unknownDocIds = new ArrayList<String>();
Collection<String> unknownDocIds = new ArrayList<>();
for (int i = 0; (i < n && i < hits.length); i++) {
String id = hits[i].getId();
if (relevantDocIds.contains(id)) {
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
unknownDocIds.add(id);
}
}
double precision = (double) good / (good + bad);
return new EvalQueryQuality(precision, unknownDocIds);
}

View File

@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.util.List;
import java.util.Vector;
/**
* Classes implementing this interface provide a means to compute the quality of a result list
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
}
return rc;
}
double combine(Vector<EvalQueryQuality> partialResults) {
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
}
}

View File

@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
/**
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
RankedListQualityMetric metric = qualityTask.getEvaluator();
double qualitySum = 0;
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>();
Map<String, Collection<String>> unknownDocs = new HashMap<>();
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
for (QuerySpec spec : specifications) {
SearchSourceBuilder specRequest = spec.getTestRequest();
String[] indices = new String[spec.getIndices().size()];
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
SearchHits hits = searchResponse.actionGet().getHits();
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
qualitySum += intentQuality.getQualityLevel();
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs());
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
partialResults.addElement(queryQuality);
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
}
RankEvalResponse response = new RankEvalResponse();
// TODO move averaging to actual metric, also add other statistics
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs);
// TODO add other statistics like micro/macro avg?
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
response.setRankEvalResult(result);
listener.onResponse(response);
}

View File

@ -32,8 +32,11 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.ExecutionException;
import static java.util.Collections.emptyList;
public class PrecisionAtNTests extends ESTestCase {
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
assertEquals(10, precicionAt.getN());
}
public void testCombine() {
PrecisionAtN metric = new PrecisionAtN();
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
partialResults.add(new EvalQueryQuality(0.1, emptyList()));
partialResults.add(new EvalQueryQuality(0.2, emptyList()));
partialResults.add(new EvalQueryQuality(0.6, emptyList()));
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
}
}

View File

@ -17,7 +17,7 @@
* under the License.
*/
package org.elasticsearch.action.quality;
package org.elasticsearch.index.rankeval;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.rankeval.PrecisionAtN;

View File

@ -17,7 +17,7 @@
* under the License.
*/
package org.elasticsearch.action.quality;
package org.elasticsearch.index.rankeval;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
import java.io.IOException;
public class RankEvalRestIT extends ESClientYamlSuiteTestCase {
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}

View File

@ -17,13 +17,10 @@
* under the License.
*/
package org.elasticsearch.action.quality;
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.index.rankeval.EvalQueryQuality;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.index.rankeval.RatedDocument;
import org.elasticsearch.index.rankeval.ReciprocalRank;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.internal.InternalSearchHit;
import org.elasticsearch.test.ESTestCase;
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
import static java.util.Collections.emptyList;
public class ReciprocalRankTests extends ESTestCase {
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
}
public void testCombine() {
ReciprocalRank reciprocalRank = new ReciprocalRank();
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
partialResults.add(new EvalQueryQuality(0.5, emptyList()));
partialResults.add(new EvalQueryQuality(1.0, emptyList()));
partialResults.add(new EvalQueryQuality(0.75, emptyList()));
assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE);
}
public void testEvaluationNoRelevantInResults() {
ReciprocalRank reciprocalRank = new ReciprocalRank();
SearchHit[] hits = new SearchHit[10];