Merge pull request #20442 from cbuescher/rankEval-removeTopLevelId

RankEval: Remove top level `spec_id`
This commit is contained in:
Christoph Büscher 2016-09-19 12:34:14 +02:00 committed by GitHub
commit 10d465f946
8 changed files with 15 additions and 60 deletions

View File

@ -44,8 +44,6 @@ import java.util.Objects;
**/ **/
//TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results //TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results
public class RankEvalResponse extends ActionResponse implements ToXContent { public class RankEvalResponse extends ActionResponse implements ToXContent {
/**ID of QA specification this result was generated for.*/
private String specId;
/**Average precision observed when issuing query intents with this specification.*/ /**Average precision observed when issuing query intents with this specification.*/
private double qualityLevel; private double qualityLevel;
/**Mapping from intent id to all documents seen for this intent that were not annotated.*/ /**Mapping from intent id to all documents seen for this intent that were not annotated.*/
@ -54,17 +52,11 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
public RankEvalResponse() { public RankEvalResponse() {
} }
public RankEvalResponse(String specId, double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) { public RankEvalResponse(double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) {
this.specId = specId;
this.qualityLevel = qualityLevel; this.qualityLevel = qualityLevel;
this.unknownDocs = unknownDocs; this.unknownDocs = unknownDocs;
} }
public String getSpecId() {
return specId;
}
public double getQualityLevel() { public double getQualityLevel() {
return qualityLevel; return qualityLevel;
} }
@ -75,13 +67,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override @Override
public String toString() { public String toString() {
return "RankEvalResponse, ID :[" + specId + "], quality: " + qualityLevel + ", unknown docs: " + unknownDocs; return "RankEvalResponse, quality: " + qualityLevel + ", unknown docs: " + unknownDocs;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out); super.writeTo(out);
out.writeString(specId);
out.writeDouble(qualityLevel); out.writeDouble(qualityLevel);
out.writeVInt(unknownDocs.size()); out.writeVInt(unknownDocs.size());
for (String queryId : unknownDocs.keySet()) { for (String queryId : unknownDocs.keySet()) {
@ -97,7 +88,6 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override @Override
public void readFrom(StreamInput in) throws IOException { public void readFrom(StreamInput in) throws IOException {
super.readFrom(in); super.readFrom(in);
this.specId = in.readString();
this.qualityLevel = in.readDouble(); this.qualityLevel = in.readDouble();
int unknownDocumentSets = in.readVInt(); int unknownDocumentSets = in.readVInt();
this.unknownDocs = new HashMap<>(unknownDocumentSets); this.unknownDocs = new HashMap<>(unknownDocumentSets);
@ -115,7 +105,6 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject("rank_eval"); builder.startObject("rank_eval");
builder.field("spec_id", specId);
builder.field("quality_level", qualityLevel); builder.field("quality_level", qualityLevel);
builder.startObject("unknown_docs"); builder.startObject("unknown_docs");
for (String key : unknownDocs.keySet()) { for (String key : unknownDocs.keySet()) {
@ -144,13 +133,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
return false; return false;
} }
RankEvalResponse other = (RankEvalResponse) obj; RankEvalResponse other = (RankEvalResponse) obj;
return Objects.equals(specId, other.specId) && return Objects.equals(qualityLevel, other.qualityLevel) &&
Objects.equals(qualityLevel, other.qualityLevel) &&
Objects.equals(unknownDocs, other.unknownDocs); Objects.equals(unknownDocs, other.unknownDocs);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(getClass(), specId, qualityLevel, unknownDocs); return Objects.hash(getClass(), qualityLevel, unknownDocs);
} }
} }

View File

@ -47,15 +47,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
private Collection<RatedRequest> ratedRequests = new ArrayList<>(); private Collection<RatedRequest> ratedRequests = new ArrayList<>();
/** Definition of the quality metric, e.g. precision at N */ /** Definition of the quality metric, e.g. precision at N */
private RankedListQualityMetric metric; private RankedListQualityMetric metric;
/** a unique id for the whole QA task */
private String specId;
public RankEvalSpec() { public RankEvalSpec() {
// TODO think if no args ctor is okay // TODO think if no args ctor is okay
} }
public RankEvalSpec(String specId, Collection<RatedRequest> specs, RankedListQualityMetric metric) { public RankEvalSpec(Collection<RatedRequest> specs, RankedListQualityMetric metric) {
this.specId = specId;
this.ratedRequests = specs; this.ratedRequests = specs;
this.metric = metric; this.metric = metric;
} }
@ -67,7 +64,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
ratedRequests.add(new RatedRequest(in)); ratedRequests.add(new RatedRequest(in));
} }
metric = in.readNamedWriteable(RankedListQualityMetric.class); metric = in.readNamedWriteable(RankedListQualityMetric.class);
specId = in.readString();
} }
@Override @Override
@ -77,21 +73,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
spec.writeTo(out); spec.writeTo(out);
} }
out.writeNamedWriteable(metric); out.writeNamedWriteable(metric);
out.writeString(specId);
} }
public void setEval(RankedListQualityMetric eval) { public void setEval(RankedListQualityMetric eval) {
this.metric = eval; this.metric = eval;
} }
public void setTaskId(String taskId) {
this.specId = taskId;
}
public String getTaskId() {
return this.specId;
}
/** Returns the precision at n configuration (containing level of n to consider).*/ /** Returns the precision at n configuration (containing level of n to consider).*/
public RankedListQualityMetric getEvaluator() { public RankedListQualityMetric getEvaluator() {
return metric; return metric;
@ -112,13 +99,11 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
this.ratedRequests = specifications; this.ratedRequests = specifications;
} }
private static final ParseField SPECID_FIELD = new ParseField("spec_id");
private static final ParseField METRIC_FIELD = new ParseField("metric"); private static final ParseField METRIC_FIELD = new ParseField("metric");
private static final ParseField REQUESTS_FIELD = new ParseField("requests"); private static final ParseField REQUESTS_FIELD = new ParseField("requests");
private static final ObjectParser<RankEvalSpec, RankEvalContext> PARSER = new ObjectParser<>("rank_eval", RankEvalSpec::new); private static final ObjectParser<RankEvalSpec, RankEvalContext> PARSER = new ObjectParser<>("rank_eval", RankEvalSpec::new);
static { static {
PARSER.declareString(RankEvalSpec::setTaskId, SPECID_FIELD);
PARSER.declareObject(RankEvalSpec::setEvaluator, (p, c) -> { PARSER.declareObject(RankEvalSpec::setEvaluator, (p, c) -> {
try { try {
return RankedListQualityMetric.fromXContent(p, c); return RankedListQualityMetric.fromXContent(p, c);
@ -138,7 +123,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(SPECID_FIELD.getPreferredName(), this.specId);
builder.startArray(REQUESTS_FIELD.getPreferredName()); builder.startArray(REQUESTS_FIELD.getPreferredName());
for (RatedRequest spec : this.ratedRequests) { for (RatedRequest spec : this.ratedRequests) {
spec.toXContent(builder, params); spec.toXContent(builder, params);
@ -162,13 +146,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
return false; return false;
} }
RankEvalSpec other = (RankEvalSpec) obj; RankEvalSpec other = (RankEvalSpec) obj;
return Objects.equals(specId, other.specId) && return Objects.equals(ratedRequests, other.ratedRequests) &&
Objects.equals(ratedRequests, other.ratedRequests) &&
Objects.equals(metric, other.metric); Objects.equals(metric, other.metric);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(specId, ratedRequests, metric); return Objects.hash(ratedRequests, metric);
} }
} }

View File

@ -112,7 +112,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
if (responseCounter.decrementAndGet() < 1) { if (responseCounter.decrementAndGet() < 1) {
// TODO add other statistics like micro/macro avg? // TODO add other statistics like micro/macro avg?
listener.onResponse( listener.onResponse(
new RankEvalResponse(task.getTaskId(), task.getEvaluator().combine(partialResults.values()), unknownDocs)); new RankEvalResponse(task.getEvaluator().combine(partialResults.values()), unknownDocs));
} }
} }

View File

@ -70,20 +70,18 @@ public class RankEvalRequestTests extends ESIntegTestCase {
List<String> indices = Arrays.asList(new String[] { "test" }); List<String> indices = Arrays.asList(new String[] { "test" });
List<String> types = Arrays.asList(new String[] { "testtype" }); List<String> types = Arrays.asList(new String[] { "testtype" });
String specId = randomAsciiOfLength(10);
List<RatedRequest> specifications = new ArrayList<>(); List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder(); SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder()); testQuery.query(new MatchAllQueryBuilder());
specifications.add(new RatedRequest("amsterdam_query", testQuery, indices, types, createRelevant("2", "3", "4", "5"))); specifications.add(new RatedRequest("amsterdam_query", testQuery, indices, types, createRelevant("2", "3", "4", "5")));
specifications.add(new RatedRequest("berlin_query", testQuery, indices, types, createRelevant("1"))); specifications.add(new RatedRequest("berlin_query", testQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10)); RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task); builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(specId, response.getSpecId());
assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE); assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE);
Set<Entry<String, Collection<RatedDocumentKey>>> entrySet = response.getUnknownDocs().entrySet(); Set<Entry<String, Collection<RatedDocumentKey>>> entrySet = response.getUnknownDocs().entrySet();
assertEquals(2, entrySet.size()); assertEquals(2, entrySet.size());
@ -104,7 +102,6 @@ public class RankEvalRequestTests extends ESIntegTestCase {
List<String> indices = Arrays.asList(new String[] { "test" }); List<String> indices = Arrays.asList(new String[] { "test" });
List<String> types = Arrays.asList(new String[] { "testtype" }); List<String> types = Arrays.asList(new String[] { "testtype" });
String specId = randomAsciiOfLength(10);
List<RatedRequest> specifications = new ArrayList<>(); List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder amsterdamQuery = new SearchSourceBuilder(); SearchSourceBuilder amsterdamQuery = new SearchSourceBuilder();
amsterdamQuery.query(new MatchAllQueryBuilder()); amsterdamQuery.query(new MatchAllQueryBuilder());
@ -114,7 +111,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
brokenQuery.query(brokenRangeQuery); brokenQuery.query(brokenRangeQuery);
specifications.add(new RatedRequest("broken_query", brokenQuery, indices, types, createRelevant("1"))); specifications.add(new RatedRequest("broken_query", brokenQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10)); RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task); builder.setRankEvalSpec(task);

View File

@ -47,7 +47,7 @@ public class RankEvalResponseTests extends ESTestCase {
} }
unknownDocs.put(randomAsciiOfLength(5), ids); unknownDocs.put(randomAsciiOfLength(5), ids);
} }
return new RankEvalResponse(randomAsciiOfLengthBetween(1, 10), randomDouble(), unknownDocs ); return new RankEvalResponse(randomDouble(), unknownDocs );
} }
public void testSerialization() throws IOException { public void testSerialization() throws IOException {

View File

@ -80,7 +80,6 @@ public class RankEvalSpecTests extends ESTestCase {
specs.add(RatedRequestsTests.createTestItem(indices, types)); specs.add(RatedRequestsTests.createTestItem(indices, types));
} }
String specId = randomAsciiOfLengthBetween(1, 10); // TODO we should reject zero length ids ...
RankedListQualityMetric metric; RankedListQualityMetric metric;
if (randomBoolean()) { if (randomBoolean()) {
metric = PrecisionAtNTests.createTestItem(); metric = PrecisionAtNTests.createTestItem();
@ -88,7 +87,7 @@ public class RankEvalSpecTests extends ESTestCase {
metric = DiscountedCumulativeGainAtTests.createTestItem(); metric = DiscountedCumulativeGainAtTests.createTestItem();
} }
RankEvalSpec testItem = new RankEvalSpec(specId, specs, metric); RankEvalSpec testItem = new RankEvalSpec(specs, metric);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = XContentTestHelper.roundtrip(testItem);

View File

@ -40,8 +40,7 @@
- do: - do:
rank_eval: rank_eval:
body: { body: {
"spec_id" : "cities_qa_queries",
"requests" : [ "requests" : [
{ {
"id": "amsterdam_query", "id": "amsterdam_query",
@ -60,7 +59,6 @@
"metric" : { "precisionatn": { "size": 10}} "metric" : { "precisionatn": { "size": 10}}
} }
- match: {rank_eval.spec_id: "cities_qa_queries"}
- match: {rank_eval.quality_level: 1} - match: {rank_eval.quality_level: 1}
- match: {rank_eval.unknown_docs.amsterdam_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]} - match: {rank_eval.unknown_docs.amsterdam_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]}
- match: {rank_eval.unknown_docs.berlin_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]} - match: {rank_eval.unknown_docs.berlin_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]}
@ -107,8 +105,7 @@
- do: - do:
rank_eval: rank_eval:
body: { body: {
"spec_id" : "cities_qa_queries",
"requests" : [ "requests" : [
{ {
"id": "amsterdam_query", "id": "amsterdam_query",
@ -126,14 +123,12 @@
"metric" : { "reciprocal_rank": {} } "metric" : { "reciprocal_rank": {} }
} }
- match: {rank_eval.spec_id: "cities_qa_queries"}
# average is (1/3 + 1/2)/2 = 5/12 ~ 0.41666666666666663 # average is (1/3 + 1/2)/2 = 5/12 ~ 0.41666666666666663
- match: {rank_eval.quality_level: 0.41666666666666663} - match: {rank_eval.quality_level: 0.41666666666666663}
- do: - do:
rank_eval: rank_eval:
body: { body: {
"spec_id" : "cities_qa_queries",
"requests" : [ "requests" : [
{ {
"id": "amsterdam_query", "id": "amsterdam_query",
@ -155,6 +150,5 @@
} }
} }
- match: {rank_eval.spec_id: "cities_qa_queries"}
# average is (0 + 1/2)/2 = 1/4 # average is (0 + 1/2)/2 = 1/4
- match: {rank_eval.quality_level: 0.25} - match: {rank_eval.quality_level: 0.25}

View File

@ -47,7 +47,6 @@
- do: - do:
rank_eval: rank_eval:
body: { body: {
"spec_id" : "dcg_qa_queries",
"requests" : [ "requests" : [
{ {
"id": "dcg_query", "id": "dcg_query",
@ -64,7 +63,6 @@
"metric" : { "dcg_at_n": { "size": 6}} "metric" : { "dcg_at_n": { "size": 6}}
} }
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 13.84826362927298} - match: {rank_eval.quality_level: 13.84826362927298}
# reverse the order in which the results are returned (less relevant docs first) # reverse the order in which the results are returned (less relevant docs first)
@ -72,7 +70,6 @@
- do: - do:
rank_eval: rank_eval:
body: { body: {
"spec_id" : "dcg_qa_queries",
"requests" : [ "requests" : [
{ {
"id": "dcg_query_reverse", "id": "dcg_query_reverse",
@ -89,7 +86,6 @@
"metric" : { "dcg_at_n": { "size": 6}} "metric" : { "dcg_at_n": { "size": 6}}
} }
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 10.29967439154499} - match: {rank_eval.quality_level: 10.29967439154499}
# if we mix both, we should get the average # if we mix both, we should get the average
@ -97,7 +93,6 @@
- do: - do:
rank_eval: rank_eval:
body: { body: {
"spec_id" : "dcg_qa_queries",
"requests" : [ "requests" : [
{ {
"id": "dcg_query", "id": "dcg_query",
@ -125,5 +120,4 @@
"metric" : { "dcg_at_n": { "size": 6}} "metric" : { "dcg_at_n": { "size": 6}}
} }
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 12.073969010408984} - match: {rank_eval.quality_level: 12.073969010408984}