mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-18 19:05:06 +00:00
Merge pull request #19686 from cbuescher/rankMetric-combine
Moving averaging of partial evaluation results to RankedListQualityMetric
This commit is contained in:
commit
acba915340
@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
||||
|
||||
int good = 0;
|
||||
int bad = 0;
|
||||
Collection<String> unknownDocIds = new ArrayList<String>();
|
||||
Collection<String> unknownDocIds = new ArrayList<>();
|
||||
for (int i = 0; (i < n && i < hits.length); i++) {
|
||||
String id = hits[i].getId();
|
||||
if (relevantDocIds.contains(id)) {
|
||||
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
||||
unknownDocIds.add(id);
|
||||
}
|
||||
}
|
||||
|
||||
double precision = (double) good / (good + bad);
|
||||
|
||||
return new EvalQueryQuality(precision, unknownDocIds);
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
||||
/**
|
||||
* Classes implementing this interface provide a means to compute the quality of a result list
|
||||
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
double combine(Vector<EvalQueryQuality> partialResults) {
|
||||
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
|
||||
}
|
||||
}
|
||||
|
@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Vector;
|
||||
|
||||
/**
|
||||
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
|
||||
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||
RankedListQualityMetric metric = qualityTask.getEvaluator();
|
||||
|
||||
double qualitySum = 0;
|
||||
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>();
|
||||
Map<String, Collection<String>> unknownDocs = new HashMap<>();
|
||||
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
|
||||
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
|
||||
for (QuerySpec spec : specifications) {
|
||||
SearchSourceBuilder specRequest = spec.getTestRequest();
|
||||
String[] indices = new String[spec.getIndices().size()];
|
||||
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
|
||||
SearchHits hits = searchResponse.actionGet().getHits();
|
||||
|
||||
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
||||
qualitySum += intentQuality.getQualityLevel();
|
||||
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs());
|
||||
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
||||
partialResults.addElement(queryQuality);
|
||||
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
|
||||
}
|
||||
|
||||
RankEvalResponse response = new RankEvalResponse();
|
||||
// TODO move averaging to actual metric, also add other statistics
|
||||
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs);
|
||||
// TODO add other statistics like micro/macro avg?
|
||||
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
|
||||
response.setRankEvalResult(result);
|
||||
listener.onResponse(response);
|
||||
}
|
||||
|
@ -32,8 +32,11 @@ import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
public class PrecisionAtNTests extends ESTestCase {
|
||||
|
||||
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
|
||||
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
|
||||
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
||||
assertEquals(10, precicionAt.getN());
|
||||
}
|
||||
|
||||
public void testCombine() {
|
||||
PrecisionAtN metric = new PrecisionAtN();
|
||||
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||
partialResults.add(new EvalQueryQuality(0.1, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(0.2, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(0.6, emptyList()));
|
||||
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
|
||||
}
|
||||
}
|
||||
|
@ -17,7 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.action.quality;
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtN;
|
@ -17,7 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.action.quality;
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RankEvalRestIT extends ESClientYamlSuiteTestCase {
|
||||
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||
public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
|
||||
public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||
super(testCandidate);
|
||||
}
|
||||
|
@ -17,13 +17,10 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.action.quality;
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.index.rankeval.EvalQueryQuality;
|
||||
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument;
|
||||
import org.elasticsearch.index.rankeval.ReciprocalRank;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
public class ReciprocalRankTests extends ESTestCase {
|
||||
|
||||
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
|
||||
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testCombine() {
|
||||
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||
partialResults.add(new EvalQueryQuality(0.5, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(1.0, emptyList()));
|
||||
partialResults.add(new EvalQueryQuality(0.75, emptyList()));
|
||||
assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testEvaluationNoRelevantInResults() {
|
||||
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||
SearchHit[] hits = new SearchHit[10];
|
Loading…
x
Reference in New Issue
Block a user