Moving averaging of partial evaluation results to RankedListQualityMetric
For the two current metrics Prec@ and reciprocal rank we currently average the partial results in the transport action. If other metric later need a different behaviour or want to parametrize this, this operation should be part of the metric itself, so this change moves it there. Also removing on of the two test packages, main code is also in one package only.
This commit is contained in:
parent
0fb7dd9054
commit
d71dc205fa
|
@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
|||
|
||||
int good = 0;
|
||||
int bad = 0;
|
||||
Collection<String> unknownDocIds = new ArrayList<String>();
|
||||
Collection<String> unknownDocIds = new ArrayList<>();
|
||||
for (int i = 0; (i < n && i < hits.length); i++) {
|
||||
String id = hits[i].getId();
|
||||
if (relevantDocIds.contains(id)) {
|
||||
|
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
|||
unknownDocIds.add(id);
|
||||
}
|
||||
}
|
||||
|
||||
double precision = (double) good / (good + bad);
|
||||
|
||||
return new EvalQueryQuality(precision, unknownDocIds);
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
||||
/**
|
||||
* Classes implementing this interface provide a means to compute the quality of a result list
|
||||
|
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
|
|||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
double combine(Vector<EvalQueryQuality> partialResults) {
|
||||
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
|
|||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Vector;
|
||||
|
||||
/**
|
||||
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
|
||||
|
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
RankedListQualityMetric metric = qualityTask.getEvaluator();
|
||||
|
||||
double qualitySum = 0;
|
||||
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>();
|
||||
Map<String, Collection<String>> unknownDocs = new HashMap<>();
|
||||
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
|
||||
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
|
||||
for (QuerySpec spec : specifications) {
|
||||
SearchSourceBuilder specRequest = spec.getTestRequest();
|
||||
String[] indices = new String[spec.getIndices().size()];
|
||||
|
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
|
||||
SearchHits hits = searchResponse.actionGet().getHits();
|
||||
|
||||
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
||||
qualitySum += intentQuality.getQualityLevel();
|
||||
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs());
|
||||
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
||||
partialResults.addElement(queryQuality);
|
||||
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
|
||||
}
|
||||
|
||||
RankEvalResponse response = new RankEvalResponse();
|
||||
// TODO move averaging to actual metric, also add other statistics
|
||||
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs);
|
||||
// TODO add other statistics like micro/macro avg?
|
||||
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
|
||||
response.setRankEvalResult(result);
|
||||
listener.onResponse(response);
|
||||
}
|
||||
|
|
|
@ -32,8 +32,11 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
public class PrecisionAtNTests extends ESTestCase {
|
||||
|
||||
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
|
||||
|
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
|
|||
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
||||
assertEquals(10, precicionAt.getN());
|
||||
}
|
||||
|
||||
public void testCombine() {
|
||||
PrecisionAtN metric = new PrecisionAtN();
|
||||
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||
partialResults.add(new EvalQueryQuality(0.1, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(0.2, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(0.6, emptyList()));
|
||||
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.action.quality;
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtN;
|
|
@ -17,7 +17,7 @@
|
|||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.action.quality;
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
|
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RankEvalRestIT extends ESClientYamlSuiteTestCase {
|
||||
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||
public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
|
||||
public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||
super(testCandidate);
|
||||
}
|
||||
|
|
@ -17,13 +17,10 @@
|
|||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.action.quality;
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.index.rankeval.EvalQueryQuality;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument;
|
||||
import org.elasticsearch.index.rankeval.ReciprocalRank;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
public class ReciprocalRankTests extends ESTestCase {
|
||||
|
||||
|
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
|
|||
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testCombine() {
|
||||
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||
partialResults.add(new EvalQueryQuality(0.5, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(1.0, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(0.75, emptyList()));
|
||||
assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testEvaluationNoRelevantInResults() {
|
||||
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||
SearchHit[] hits = new SearchHit[10];
|
Loading…
Reference in New Issue