mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-22 12:56:53 +00:00
Merge pull request #19686 from cbuescher/rankMetric-combine
Moving averaging of partial evaluation results to RankedListQualityMetric
This commit is contained in:
commit
acba915340
@ -111,7 +111,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
|||||||
|
|
||||||
int good = 0;
|
int good = 0;
|
||||||
int bad = 0;
|
int bad = 0;
|
||||||
Collection<String> unknownDocIds = new ArrayList<String>();
|
Collection<String> unknownDocIds = new ArrayList<>();
|
||||||
for (int i = 0; (i < n && i < hits.length); i++) {
|
for (int i = 0; (i < n && i < hits.length); i++) {
|
||||||
String id = hits[i].getId();
|
String id = hits[i].getId();
|
||||||
if (relevantDocIds.contains(id)) {
|
if (relevantDocIds.contains(id)) {
|
||||||
@ -122,9 +122,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
|
|||||||
unknownDocIds.add(id);
|
unknownDocIds.add(id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double precision = (double) good / (good + bad);
|
double precision = (double) good / (good + bad);
|
||||||
|
|
||||||
return new EvalQueryQuality(precision, unknownDocIds);
|
return new EvalQueryQuality(precision, unknownDocIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ import org.elasticsearch.search.SearchHit;
|
|||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Vector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Classes implementing this interface provide a means to compute the quality of a result list
|
* Classes implementing this interface provide a means to compute the quality of a result list
|
||||||
@ -71,4 +72,8 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
|
|||||||
}
|
}
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double combine(Vector<EvalQueryQuality> partialResults) {
|
||||||
|
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,6 +44,7 @@ import org.elasticsearch.transport.TransportService;
|
|||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Vector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
|
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
|
||||||
@ -85,8 +86,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||||||
RankedListQualityMetric metric = qualityTask.getEvaluator();
|
RankedListQualityMetric metric = qualityTask.getEvaluator();
|
||||||
|
|
||||||
double qualitySum = 0;
|
double qualitySum = 0;
|
||||||
Map<String, Collection<String>> unknownDocs = new HashMap<String, Collection<String>>();
|
Map<String, Collection<String>> unknownDocs = new HashMap<>();
|
||||||
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
|
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
|
||||||
|
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
|
||||||
for (QuerySpec spec : specifications) {
|
for (QuerySpec spec : specifications) {
|
||||||
SearchSourceBuilder specRequest = spec.getTestRequest();
|
SearchSourceBuilder specRequest = spec.getTestRequest();
|
||||||
String[] indices = new String[spec.getIndices().size()];
|
String[] indices = new String[spec.getIndices().size()];
|
||||||
@ -101,13 +103,14 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||||||
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
|
ActionFuture<SearchResponse> searchResponse = transportSearchAction.execute(templatedRequest);
|
||||||
SearchHits hits = searchResponse.actionGet().getHits();
|
SearchHits hits = searchResponse.actionGet().getHits();
|
||||||
|
|
||||||
EvalQueryQuality intentQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
|
||||||
qualitySum += intentQuality.getQualityLevel();
|
partialResults.addElement(queryQuality);
|
||||||
unknownDocs.put(spec.getSpecId(), intentQuality.getUnknownDocs());
|
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
|
||||||
}
|
}
|
||||||
|
|
||||||
RankEvalResponse response = new RankEvalResponse();
|
RankEvalResponse response = new RankEvalResponse();
|
||||||
// TODO move averaging to actual metric, also add other statistics
|
// TODO add other statistics like micro/macro avg?
|
||||||
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), qualitySum / specifications.size(), unknownDocs);
|
RankEvalResult result = new RankEvalResult(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
|
||||||
response.setRankEvalResult(result);
|
response.setRankEvalResult(result);
|
||||||
listener.onResponse(response);
|
listener.onResponse(response);
|
||||||
}
|
}
|
||||||
|
@ -32,8 +32,11 @@ import java.io.IOException;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Vector;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
|
import static java.util.Collections.emptyList;
|
||||||
|
|
||||||
public class PrecisionAtNTests extends ESTestCase {
|
public class PrecisionAtNTests extends ESTestCase {
|
||||||
|
|
||||||
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
|
public void testPrecisionAtFiveCalculation() throws IOException, InterruptedException, ExecutionException {
|
||||||
@ -66,4 +69,13 @@ public class PrecisionAtNTests extends ESTestCase {
|
|||||||
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
PrecisionAtN precicionAt = PrecisionAtN.fromXContent(parser, () -> ParseFieldMatcher.STRICT);
|
||||||
assertEquals(10, precicionAt.getN());
|
assertEquals(10, precicionAt.getN());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCombine() {
|
||||||
|
PrecisionAtN metric = new PrecisionAtN();
|
||||||
|
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||||
|
partialResults.add(new EvalQueryQuality(0.1, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(0.2, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(0.6, emptyList()));
|
||||||
|
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.action.quality;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||||
import org.elasticsearch.index.rankeval.PrecisionAtN;
|
import org.elasticsearch.index.rankeval.PrecisionAtN;
|
@ -17,7 +17,7 @@
|
|||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.action.quality;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||||
@ -28,8 +28,8 @@ import org.elasticsearch.test.rest.yaml.parser.ClientYamlTestParseException;
|
|||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public class RankEvalRestIT extends ESClientYamlSuiteTestCase {
|
public class RankEvalYamlIT extends ESClientYamlSuiteTestCase {
|
||||||
public RankEvalRestIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
public RankEvalYamlIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
|
||||||
super(testCandidate);
|
super(testCandidate);
|
||||||
}
|
}
|
||||||
|
|
@ -17,13 +17,10 @@
|
|||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.action.quality;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.common.text.Text;
|
import org.elasticsearch.common.text.Text;
|
||||||
import org.elasticsearch.index.rankeval.EvalQueryQuality;
|
|
||||||
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
|
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
|
||||||
import org.elasticsearch.index.rankeval.RatedDocument;
|
|
||||||
import org.elasticsearch.index.rankeval.ReciprocalRank;
|
|
||||||
import org.elasticsearch.search.SearchHit;
|
import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.search.internal.InternalSearchHit;
|
import org.elasticsearch.search.internal.InternalSearchHit;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
@ -31,6 +28,9 @@ import org.elasticsearch.test.ESTestCase;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Vector;
|
||||||
|
|
||||||
|
import static java.util.Collections.emptyList;
|
||||||
|
|
||||||
public class ReciprocalRankTests extends ESTestCase {
|
public class ReciprocalRankTests extends ESTestCase {
|
||||||
|
|
||||||
@ -86,6 +86,15 @@ public class ReciprocalRankTests extends ESTestCase {
|
|||||||
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
|
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCombine() {
|
||||||
|
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||||
|
Vector<EvalQueryQuality> partialResults = new Vector<>(3);
|
||||||
|
partialResults.add(new EvalQueryQuality(0.5, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(1.0, emptyList()));
|
||||||
|
partialResults.add(new EvalQueryQuality(0.75, emptyList()));
|
||||||
|
assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
public void testEvaluationNoRelevantInResults() {
|
public void testEvaluationNoRelevantInResults() {
|
||||||
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
ReciprocalRank reciprocalRank = new ReciprocalRank();
|
||||||
SearchHit[] hits = new SearchHit[10];
|
SearchHit[] hits = new SearchHit[10];
|
Loading…
x
Reference in New Issue
Block a user