Merge pull request #20442 from cbuescher/rankEval-removeTopLevelId

RankEval: Remove top level `spec_id`
This commit is contained in:
Christoph Büscher 2016-09-19 12:34:14 +02:00 committed by GitHub
commit 10d465f946
8 changed files with 15 additions and 60 deletions

View File

@ -44,8 +44,6 @@ import java.util.Objects;
**/
//TODO instead of just returning averages over complete results, think of other statistics, micro avg, macro avg, partial results
public class RankEvalResponse extends ActionResponse implements ToXContent {
/**ID of QA specification this result was generated for.*/
private String specId;
/**Overall quality level, as computed by the configured metric, observed when issuing query intents with this specification.*/
private double qualityLevel;
/**Mapping from intent id to all documents seen for this intent that were not annotated.*/
@ -54,17 +52,11 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
public RankEvalResponse() {
}
public RankEvalResponse(String specId, double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) {
this.specId = specId;
public RankEvalResponse(double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) {
this.qualityLevel = qualityLevel;
this.unknownDocs = unknownDocs;
}
public String getSpecId() {
return specId;
}
public double getQualityLevel() {
return qualityLevel;
}
@ -75,13 +67,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override
public String toString() {
return "RankEvalResponse, ID :[" + specId + "], quality: " + qualityLevel + ", unknown docs: " + unknownDocs;
return "RankEvalResponse, quality: " + qualityLevel + ", unknown docs: " + unknownDocs;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(specId);
out.writeDouble(qualityLevel);
out.writeVInt(unknownDocs.size());
for (String queryId : unknownDocs.keySet()) {
@ -97,7 +88,6 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
this.specId = in.readString();
this.qualityLevel = in.readDouble();
int unknownDocumentSets = in.readVInt();
this.unknownDocs = new HashMap<>(unknownDocumentSets);
@ -115,7 +105,6 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject("rank_eval");
builder.field("spec_id", specId);
builder.field("quality_level", qualityLevel);
builder.startObject("unknown_docs");
for (String key : unknownDocs.keySet()) {
@ -144,13 +133,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
return false;
}
RankEvalResponse other = (RankEvalResponse) obj;
return Objects.equals(specId, other.specId) &&
Objects.equals(qualityLevel, other.qualityLevel) &&
return Objects.equals(qualityLevel, other.qualityLevel) &&
Objects.equals(unknownDocs, other.unknownDocs);
}
@Override
public final int hashCode() {
return Objects.hash(getClass(), specId, qualityLevel, unknownDocs);
return Objects.hash(getClass(), qualityLevel, unknownDocs);
}
}

View File

@ -47,15 +47,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
private Collection<RatedRequest> ratedRequests = new ArrayList<>();
/** Definition of the quality metric, e.g. precision at N */
private RankedListQualityMetric metric;
/** a unique id for the whole QA task */
private String specId;
public RankEvalSpec() {
// TODO think if no args ctor is okay
}
public RankEvalSpec(String specId, Collection<RatedRequest> specs, RankedListQualityMetric metric) {
this.specId = specId;
public RankEvalSpec(Collection<RatedRequest> specs, RankedListQualityMetric metric) {
this.ratedRequests = specs;
this.metric = metric;
}
@ -67,7 +64,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
ratedRequests.add(new RatedRequest(in));
}
metric = in.readNamedWriteable(RankedListQualityMetric.class);
specId = in.readString();
}
@Override
@ -77,21 +73,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
spec.writeTo(out);
}
out.writeNamedWriteable(metric);
out.writeString(specId);
}
public void setEval(RankedListQualityMetric eval) {
this.metric = eval;
}
public void setTaskId(String taskId) {
this.specId = taskId;
}
public String getTaskId() {
return this.specId;
}
/** Returns the metric used to evaluate the ranked results (e.g. precision at n). */
public RankedListQualityMetric getEvaluator() {
return metric;
@ -112,13 +99,11 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
this.ratedRequests = specifications;
}
private static final ParseField SPECID_FIELD = new ParseField("spec_id");
private static final ParseField METRIC_FIELD = new ParseField("metric");
private static final ParseField REQUESTS_FIELD = new ParseField("requests");
private static final ObjectParser<RankEvalSpec, RankEvalContext> PARSER = new ObjectParser<>("rank_eval", RankEvalSpec::new);
static {
PARSER.declareString(RankEvalSpec::setTaskId, SPECID_FIELD);
PARSER.declareObject(RankEvalSpec::setEvaluator, (p, c) -> {
try {
return RankedListQualityMetric.fromXContent(p, c);
@ -138,7 +123,6 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SPECID_FIELD.getPreferredName(), this.specId);
builder.startArray(REQUESTS_FIELD.getPreferredName());
for (RatedRequest spec : this.ratedRequests) {
spec.toXContent(builder, params);
@ -162,13 +146,12 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
return false;
}
RankEvalSpec other = (RankEvalSpec) obj;
return Objects.equals(specId, other.specId) &&
Objects.equals(ratedRequests, other.ratedRequests) &&
return Objects.equals(ratedRequests, other.ratedRequests) &&
Objects.equals(metric, other.metric);
}
@Override
public final int hashCode() {
return Objects.hash(specId, ratedRequests, metric);
return Objects.hash(ratedRequests, metric);
}
}

View File

@ -112,7 +112,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
if (responseCounter.decrementAndGet() < 1) {
// TODO add other statistics like micro/macro avg?
listener.onResponse(
new RankEvalResponse(task.getTaskId(), task.getEvaluator().combine(partialResults.values()), unknownDocs));
new RankEvalResponse(task.getEvaluator().combine(partialResults.values()), unknownDocs));
}
}

View File

@ -70,20 +70,18 @@ public class RankEvalRequestTests extends ESIntegTestCase {
List<String> indices = Arrays.asList(new String[] { "test" });
List<String> types = Arrays.asList(new String[] { "testtype" });
String specId = randomAsciiOfLength(10);
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
specifications.add(new RatedRequest("amsterdam_query", testQuery, indices, types, createRelevant("2", "3", "4", "5")));
specifications.add(new RatedRequest("berlin_query", testQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10));
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(specId, response.getSpecId());
assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE);
Set<Entry<String, Collection<RatedDocumentKey>>> entrySet = response.getUnknownDocs().entrySet();
assertEquals(2, entrySet.size());
@ -104,7 +102,6 @@ public class RankEvalRequestTests extends ESIntegTestCase {
List<String> indices = Arrays.asList(new String[] { "test" });
List<String> types = Arrays.asList(new String[] { "testtype" });
String specId = randomAsciiOfLength(10);
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder amsterdamQuery = new SearchSourceBuilder();
amsterdamQuery.query(new MatchAllQueryBuilder());
@ -114,7 +111,7 @@ public class RankEvalRequestTests extends ESIntegTestCase {
brokenQuery.query(brokenRangeQuery);
specifications.add(new RatedRequest("broken_query", brokenQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10));
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtN(10));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);

View File

@ -47,7 +47,7 @@ public class RankEvalResponseTests extends ESTestCase {
}
unknownDocs.put(randomAsciiOfLength(5), ids);
}
return new RankEvalResponse(randomAsciiOfLengthBetween(1, 10), randomDouble(), unknownDocs );
return new RankEvalResponse(randomDouble(), unknownDocs );
}
public void testSerialization() throws IOException {

View File

@ -80,7 +80,6 @@ public class RankEvalSpecTests extends ESTestCase {
specs.add(RatedRequestsTests.createTestItem(indices, types));
}
String specId = randomAsciiOfLengthBetween(1, 10); // TODO we should reject zero length ids ...
RankedListQualityMetric metric;
if (randomBoolean()) {
metric = PrecisionAtNTests.createTestItem();
@ -88,7 +87,7 @@ public class RankEvalSpecTests extends ESTestCase {
metric = DiscountedCumulativeGainAtTests.createTestItem();
}
RankEvalSpec testItem = new RankEvalSpec(specId, specs, metric);
RankEvalSpec testItem = new RankEvalSpec(specs, metric);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);

View File

@ -41,7 +41,6 @@
- do:
rank_eval:
body: {
"spec_id" : "cities_qa_queries",
"requests" : [
{
"id": "amsterdam_query",
@ -60,7 +59,6 @@
"metric" : { "precisionatn": { "size": 10}}
}
- match: {rank_eval.spec_id: "cities_qa_queries"}
- match: {rank_eval.quality_level: 1}
- match: {rank_eval.unknown_docs.amsterdam_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]}
- match: {rank_eval.unknown_docs.berlin_query: [ {"index": "foo", "type": "bar", "doc_id": "doc4"}]}
@ -108,7 +106,6 @@
- do:
rank_eval:
body: {
"spec_id" : "cities_qa_queries",
"requests" : [
{
"id": "amsterdam_query",
@ -126,14 +123,12 @@
"metric" : { "reciprocal_rank": {} }
}
- match: {rank_eval.spec_id: "cities_qa_queries"}
# average is (1/3 + 1/2)/2 = 5/12 ~ 0.41666666666666663
- match: {rank_eval.quality_level: 0.41666666666666663}
- do:
rank_eval:
body: {
"spec_id" : "cities_qa_queries",
"requests" : [
{
"id": "amsterdam_query",
@ -155,6 +150,5 @@
}
}
- match: {rank_eval.spec_id: "cities_qa_queries"}
# average is (0 + 1/2)/2 = 1/4
- match: {rank_eval.quality_level: 0.25}

View File

@ -47,7 +47,6 @@
- do:
rank_eval:
body: {
"spec_id" : "dcg_qa_queries",
"requests" : [
{
"id": "dcg_query",
@ -64,7 +63,6 @@
"metric" : { "dcg_at_n": { "size": 6}}
}
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 13.84826362927298}
# reverse the order in which the results are returned (less relevant docs first)
@ -72,7 +70,6 @@
- do:
rank_eval:
body: {
"spec_id" : "dcg_qa_queries",
"requests" : [
{
"id": "dcg_query_reverse",
@ -89,7 +86,6 @@
"metric" : { "dcg_at_n": { "size": 6}}
}
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 10.29967439154499}
# if we mix both, we should get the average
@ -97,7 +93,6 @@
- do:
rank_eval:
body: {
"spec_id" : "dcg_qa_queries",
"requests" : [
{
"id": "dcg_query",
@ -125,5 +120,4 @@
"metric" : { "dcg_at_n": { "size": 6}}
}
- match: {rank_eval.spec_id: "dcg_qa_queries"}
- match: {rank_eval.quality_level: 12.073969010408984}