Merge pull request #20442 from cbuescher/rankEval-removeTopLevelId
RankEval: Remove top level `spec_id`
This commit is contained in:
commit
10d465f946
|
@ -44,8 +44,6 @@ import java.util.Objects;
|
|||
**/
|
||||
//TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results
|
||||
public class RankEvalResponse extends ActionResponse implements ToXContent {
|
||||
/**ID of QA specification this result was generated for.*/
|
||||
private String specId;
|
||||
/**Average precision observed when issuing query intents with this specification.*/
|
||||
private double qualityLevel;
|
||||
/**Mapping from intent id to all documents seen for this intent that were not annotated.*/
|
||||
|
@ -54,17 +52,11 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
|
|||
public RankEvalResponse() {
|
||||
}
|
||||
|
||||
public RankEvalResponse(String specId, double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) {
|
||||
this.specId = specId;
|
||||
public RankEvalResponse(double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) {
|
||||
this.qualityLevel = qualityLevel;
|
||||
this.unknownDocs = unknownDocs;
|
||||
}
|
||||
|
||||
|
||||
public String getSpecId() {
|
||||
return specId;
|
||||
}
|
||||
|
||||
public double getQualityLevel() {
|
||||
return qualityLevel;
|
||||
}
|
||||
|
@ -75,13 +67,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
|
|||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "RankEvalResponse, ID :[" + specId + "], quality: " + qualityLevel + ", unknown docs: " + unknownDocs;
|
||||
return "RankEvalResponse, quality: " + qualityLevel + ", unknown docs: " + unknownDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeString(specId);
|
||||
out.writeDouble(qualityLevel);
|
||||
out.writeVInt(unknownDocs.size());
|
||||
for (String queryId : unknownDocs.keySet()) {
|
||||
|
@ -97,7 +88,6 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
|
|||
@Override
|
||||
public void readFrom(StreamInput in) throws IOException {
|
||||
super.readFrom(in);
|
||||
this.specId = in.readString();
|
||||
this.qualityLevel = in.readDouble();
|
||||
int unknownDocumentSets = in.readVInt();
|
||||
this.unknownDocs = new HashMap<>(unknownDocumentSets);
|
||||
|
@ -115,7 +105,6 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject("rank_eval");
|
||||
builder.field("spec_id", specId);
|
||||
builder.field("quality_level", qualityLevel);
|
||||
builder.startObject("unknown_docs");
|
||||
for (String key : unknownDocs.keySet()) {
|
||||
|
@ -144,13 +133,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
|
|||
return false;
|
||||
}
|
||||
RankEvalResponse other = (RankEvalResponse) obj;
|
||||
return Objects.equals(specId, other.specId) &&
|
||||
Objects.equals(qualityLevel, other.qualityLevel) &&
|
||||
return Objects.equals(qualityLevel, other.qualityLevel) &&
|
||||
Objects.equals(unknownDocs, other.unknownDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(getClass(), specId, qualityLevel, unknownDocs);
|
||||
return Objects.hash(getClass(), qualityLevel, unknownDocs);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,15 +47,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
|
|||
private Collection<RatedRequest> ratedRequests = new ArrayList<>();
|
||||
/** Definition of the quality metric, e.g. precision at N */
|
||||
private RankedListQualityMetric metric;
|
||||
/** a unique id for the whole QA task */
|
||||
private String specId;
|
||||
|
||||
public RankEvalSpec() {
|
||||
// TODO think if no args ctor is okay
|
||||
}
|
||||
|
||||
public RankEvalSpec(String specId, Collection<RatedRequest> specs, RankedListQualityMetric metric) {
|
||||
this.specId = specId;
|
||||
public RankEvalSpec(Collection<RatedRequest> specs, RankedListQualityMetric metric) {
|
||||
this.ratedRequests = specs;
|
||||
this.metric = metric;
|
||||
}
|
||||
|
@ -67,7 +64,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
|
|||
ratedRequests.add(new RatedRequest(in));
|
||||
}
|
||||
metric = in.readNamedWriteable(RankedListQualityMetric.class);
|
||||
specId = in.readString();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -77,21 +73,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
|
|||
spec.writeTo(out);
|
||||
}
|
||||
out.writeNamedWriteable(metric);
|
||||
out.writeString(specId);
|
||||
}
|
||||
|
||||
public void setEval(RankedListQualityMetric eval) {
|
||||
this.metric = eval;
|
||||
}
|
||||
|
||||
public void setTaskId(String taskId) {
|
||||
this.specId = taskId;
|
||||
}
|
||||
|
||||
public String getTaskId() {
|
||||
return this.specId;
|
||||
}
|
||||
|
||||
/** Returns the precision at n configuration (containing level of n to consider).*/
|
||||
public RankedListQualityMetric getEvaluator() {
|
||||
return metric;
|
||||
|
@ -112,13 +99,11 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
|
|||
this.ratedRequests = specifications;
|
||||
}
|
||||
|
||||
private static final ParseField SPECID_FIELD = new ParseField("spec_id");
|
||||
private static final ParseField METRIC_FIELD = new ParseField("metric");
|
||||
private static final ParseField REQUESTS_FIELD = new ParseField("requests");
|
||||
private static final ObjectParser<RankEvalSpec, RankEvalContext> PARSER = new ObjectParser<>("rank_eval", RankEvalSpec::new);
|
||||
|
||||
static {
|
||||
PARSER.declareString(RankEvalSpec::setTaskId, SPECID_FIELD);
|
||||
PARSER.declareObject(RankEvalSpec::setEvaluator, (p, c) -> {
|
||||
try {
|
||||
return RankedListQualityMetric.fromXContent(p, c);
|
||||
|
@ -138,7 +123,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SPECID_FIELD.getPreferredName(), this.specId);
|
||||
builder.startArray(REQUESTS_FIELD.getPreferredName());
|
||||
for (RatedRequest spec : this.ratedRequests) {
|
||||
spec.toXContent(builder, params);
|
||||
|
@ -162,13 +146,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
|
|||
return false;
|
||||
}
|
||||
RankEvalSpec other = (RankEvalSpec) obj;
|
||||
return Objects.equals(specId, other.specId) &&
|
||||
Objects.equals(ratedRequests, other.ratedRequests) &&
|
||||
return Objects.equals(ratedRequests, other.ratedRequests) &&
|
||||
Objects.equals(metric, other.metric);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(specId, ratedRequests, metric);
|
||||
return Objects.hash(ratedRequests, metric);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
if (responseCounter.decrementAndGet() < 1) {
|
||||
// TODO add other statistics like micro/macro avg?
|
||||
listener.onResponse(
|
||||
new RankEvalResponse(task.getTaskId(), task.getEvaluator().combine(partialResults.values()), unknownDocs));
|
||||
new RankEvalResponse(task.getEvaluator().combine(partialResults.values()), unknownDocs));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -70,20 +70,18 @@ public class RankEvalRequestTests extends ESIntegTestCase {
|
|||
List<String> indices = Arrays.asList(new String[] { "test" });
|
||||
List<String> types = Arrays.asList(new String[] { "testtype" });
|
||||
|
||||
String specId = randomAsciiOfLength(10);
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
||||
testQuery.query(new MatchAllQueryBuilder());
|
||||
specifications.add(new RatedRequest("amsterdam_query", testQuery, indices, types, createRelevant("2", "3", "4", "5")));
|
||||
specifications.add(new RatedRequest("berlin_query", testQuery, indices, types, createRelevant("1")));
|
||||
|
||||
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10));
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(specId, response.getSpecId());
|
||||
assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE);
|
||||
Set<Entry<String, Collection<RatedDocumentKey>>> entrySet = response.getUnknownDocs().entrySet();
|
||||
assertEquals(2, entrySet.size());
|
||||
|
@ -104,7 +102,6 @@ public class RankEvalRequestTests extends ESIntegTestCase {
|
|||
List<String> indices = Arrays.asList(new String[] { "test" });
|
||||
List<String> types = Arrays.asList(new String[] { "testtype" });
|
||||
|
||||
String specId = randomAsciiOfLength(10);
|
||||
List<RatedRequest> specifications = new ArrayList<>();
|
||||
SearchSourceBuilder amsterdamQuery = new SearchSourceBuilder();
|
||||
amsterdamQuery.query(new MatchAllQueryBuilder());
|
||||
|
@ -114,7 +111,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
|
|||
brokenQuery.query(brokenRangeQuery);
|
||||
specifications.add(new RatedRequest("broken_query", brokenQuery, indices, types, createRelevant("1")));
|
||||
|
||||
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10));
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
|
|
@ -47,7 +47,7 @@ public class RankEvalResponseTests extends ESTestCase {
|
|||
}
|
||||
unknownDocs.put(randomAsciiOfLength(5), ids);
|
||||
}
|
||||
return new RankEvalResponse(randomAsciiOfLengthBetween(1, 10), randomDouble(), unknownDocs );
|
||||
return new RankEvalResponse(randomDouble(), unknownDocs );
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
|
|
|
@ -80,7 +80,6 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
specs.add(RatedRequestsTests.createTestItem(indices, types));
|
||||
}
|
||||
|
||||
String specId = randomAsciiOfLengthBetween(1, 10); // TODO we should reject zero length ids ...
|
||||
RankedListQualityMetric metric;
|
||||
if (randomBoolean()) {
|
||||
metric = PrecisionAtNTests.createTestItem();
|
||||
|
@ -88,7 +87,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
metric = DiscountedCumulativeGainAtTests.createTestItem();
|
||||
}
|
||||
|
||||
RankEvalSpec testItem = new RankEvalSpec(specId, specs, metric);
|
||||
RankEvalSpec testItem = new RankEvalSpec(specs, metric);
|
||||
|
||||
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
|
||||
|
||||
|
|
|
@ -40,8 +40,7 @@
|
|||
|
||||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"spec_id" : "cities_qa_queries",
|
||||
body: {
|
||||
"requests" : [
|
||||
{
|
||||
"id": "amsterdam_query",
|
||||
|
@ -60,7 +59,6 @@
|
|||
"metric" : { "precisionatn": { "size": 10}}
|
||||
}
|
||||
|
||||
- match: {rank_eval.spec_id: "cities_qa_queries"}
|
||||
- match: {rank_eval.quality_level: 1}
|
||||
- match: {rank_eval.unknown_docs.amsterdam_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]}
|
||||
- match: {rank_eval.unknown_docs.berlin_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]}
|
||||
|
@ -107,8 +105,7 @@
|
|||
|
||||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"spec_id" : "cities_qa_queries",
|
||||
body: {
|
||||
"requests" : [
|
||||
{
|
||||
"id": "amsterdam_query",
|
||||
|
@ -126,14 +123,12 @@
|
|||
"metric" : { "reciprocal_rank": {} }
|
||||
}
|
||||
|
||||
- match: {rank_eval.spec_id: "cities_qa_queries"}
|
||||
# average is (1/3 + 1/2)/2 = 5/12 ~ 0.41666666666666663
|
||||
- match: {rank_eval.quality_level: 0.41666666666666663}
|
||||
|
||||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"spec_id" : "cities_qa_queries",
|
||||
"requests" : [
|
||||
{
|
||||
"id": "amsterdam_query",
|
||||
|
@ -155,6 +150,5 @@
|
|||
}
|
||||
}
|
||||
|
||||
- match: {rank_eval.spec_id: "cities_qa_queries"}
|
||||
# average is (0 + 1/2)/2 = 1/4
|
||||
- match: {rank_eval.quality_level: 0.25}
|
||||
|
|
|
@ -47,7 +47,6 @@
|
|||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"spec_id" : "dcg_qa_queries",
|
||||
"requests" : [
|
||||
{
|
||||
"id": "dcg_query",
|
||||
|
@ -64,7 +63,6 @@
|
|||
"metric" : { "dcg_at_n": { "size": 6}}
|
||||
}
|
||||
|
||||
- match: {rank_eval.spec_id: "dcg_qa_queries"}
|
||||
- match: {rank_eval.quality_level: 13.84826362927298}
|
||||
|
||||
# reverse the order in which the results are returned (less relevant docs first)
|
||||
|
@ -72,7 +70,6 @@
|
|||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"spec_id" : "dcg_qa_queries",
|
||||
"requests" : [
|
||||
{
|
||||
"id": "dcg_query_reverse",
|
||||
|
@ -89,7 +86,6 @@
|
|||
"metric" : { "dcg_at_n": { "size": 6}}
|
||||
}
|
||||
|
||||
- match: {rank_eval.spec_id: "dcg_qa_queries"}
|
||||
- match: {rank_eval.quality_level: 10.29967439154499}
|
||||
|
||||
# if we mix both, we should get the average
|
||||
|
@ -97,7 +93,6 @@
|
|||
- do:
|
||||
rank_eval:
|
||||
body: {
|
||||
"spec_id" : "dcg_qa_queries",
|
||||
"requests" : [
|
||||
{
|
||||
"id": "dcg_query",
|
||||
|
@ -125,5 +120,4 @@
|
|||
"metric" : { "dcg_at_n": { "size": 6}}
|
||||
}
|
||||
|
||||
- match: {rank_eval.spec_id: "dcg_qa_queries"}
|
||||
- match: {rank_eval.quality_level: 12.073969010408984}
|
||||
|
|
Loading…
Reference in New Issue