Moving averaging of partial evaluation results to RankedListQualityMetric
For the two current metrics Prec@ and reciprocal rank we currently average the partial results in the transport action. If other metric later need a different behaviour or want to parametrize this, this operation should be part of the metric itself, so this change moves it there. Also removing on of the two test packages, main code is also in one package only.
This commit is contained in:
parent
0fb7dd9054
commit
d71dc205fa
|
@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
||||||
|
|
||||||
int good = 0;
|
int good = 0;
|
||||||
int bad = 0;
|
int bad = 0;
|
||||||
Collection<String> unknownDocIds = new ArrayList<String>();
|
Collection<String> unknownDocIds = new ArrayList<>();
|
||||||
for (int i = 0; (i < n && i < hits.length); i++) {
|
for (int i = 0; (i < n && i < hits.length); i++) {
|
||||||
String id = hits[i].getId();
|
String id = hits[i].getId();
|
||||||
if (relevantDocIds.contains(id)) {
|
if (relevantDocIds.contains(id)) {
|
||||||
|
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
||||||
unknownDocIds.add(id);
|
unknownDocIds.add(id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double precision = (double) good / (good + bad);
|
double precision = (double) good / (good + bad);
|
||||||
|
|
||||||
return new EvalQueryQuality(precision, unknownDocIds);
|
return new EvalQueryQuality(precision, unknownDocIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Vector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Classes implementing this interface provide a means to compute the quality of a result list
|
* Classes implementing this interface provide a means to compute the quality of a result list
|
||||||
|
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
|
||||||
}
|
}
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double combine(Vector<EvalQueryQuality> partialResults) {
|
||||||
|
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Vector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
|
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
|
||||||
|
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||||
RankedListQualityMetric metric = qualityTask.getEvaluator();
|
RankedListQualityMetric metric = qualityTask.getEvaluator();
|
||||||
|
|
||||||
double qualitySum = 0;
|
double qualitySum = 0;
|
||||||
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>();
|
Map<String, Collection<String>> unknownDocs = new HashMap<>();
|
||||||
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
|
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
|
||||||
|
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
|
||||||
for (QuerySpec spec : specifications) {
|
for (QuerySpec spec : specifications) {
|
||||||
SearchSourceBuilder specRequest = spec.getTestRequest();
|
SearchSourceBuilder specRequest = spec.getTestRequest();
|
||||||
String[] indices = new String[spec.getIndices().size()];
|
String[] indices = new String[spec.getIndices().size()];
|
||||||
|
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||||
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
|
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
|
||||||
SearchHits hits = searchResponse.actionGet().getHits();
|
SearchHits hits = searchResponse.actionGet().getHits();
|
||||||
|
|
||||||
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
||||||
qualitySum += intentQuality.getQualityLevel();
|
partialResults.addElement(queryQuality);
|
||||||
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs());
|
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
|
||||||
}
|
}
|
||||||
|
|
||||||
RankEvalResponse response = new RankEvalResponse();
|
RankEvalResponse response = new RankEvalResponse();
|
||||||
// TODO move averaging to actual metric, also add other statistics
|
// TODO add other statistics like micro/macro avg?
|
||||||
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs);
|
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
|
||||||
response.setRankEvalResult(result);
|
response.setRankEvalResult(result);
|
||||||
listener.onResponse(response);
|
listener.onResponse(response);
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,8 +32,11 @@ import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Vector;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
|
import static java.util.Collections.emptyList;
|
||||||
|
|
||||||
public class PrecisionAtNTests extends ESTestCase {
|
public class PrecisionAtNTests extends ESTestCase {
|
||||||
|
|
||||||
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
|
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
|
||||||
|
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
|
||||||
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
||||||
assertEquals(10, precicionAt.getN());
|
assertEquals(10, precicionAt.getN());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCombine() {
|
||||||
|
PrecisionAtN metric = new PrecisionAtN();
|
||||||
|
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||||
|
partialResults.add(new EvalQueryQuality(0.1, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(0.2, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(0.6, emptyList()));
|
||||||
|
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.action.quality;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||||
import org.elasticsearch.index.rankeval.PrecisionAtN;
|
import org.elasticsearch.index.rankeval.PrecisionAtN;
|
|
@ -17,7 +17,7 @@
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.action.quality;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||||
|
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public class RankEvalRestIT extends ESClientYamlSuiteTestCase {
|
public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
|
||||||
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||||
super(testCandidate);
|
super(testCandidate);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,13 +17,10 @@
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.action.quality;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.common.text.Text;
|
import org.elasticsearch.common.text.Text;
|
||||||
import org.elasticsearch.index.rankeval.EvalQueryQuality;
|
|
||||||
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
|
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
|
||||||
import org.elasticsearch.index.rankeval.RatedDocument;
|
|
||||||
import org.elasticsearch.index.rankeval.ReciprocalRank;
|
|
||||||
import org.elasticsearch.search.SearchHit;
|
import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.search.internal.InternalSearchHit;
|
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Vector;
|
||||||
|
|
||||||
|
import static java.util.Collections.emptyList;
|
||||||
|
|
||||||
public class ReciprocalRankTests extends ESTestCase {
|
public class ReciprocalRankTests extends ESTestCase {
|
||||||
|
|
||||||
|
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
|
||||||
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
|
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCombine() {
|
||||||
|
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||||
|
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||||
|
partialResults.add(new EvalQueryQuality(0.5, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(1.0, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(0.75, emptyList()));
|
||||||
|
assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
public void testEvaluationNoRelevantInResults() {
|
public void testEvaluationNoRelevantInResults() {
|
||||||
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||||
SearchHit[] hits = new SearchHit[10];
|
SearchHit[] hits = new SearchHit[10];
|
Loading…
Reference in New Issue