RankEval: Adding details section to response (#20497)

In order to understand how well particular queries in a joint ranking evaluation 
request work we want to break down the overall metric into its components, each
contributed by a particular query. The response structure now has a
`details` section under which we can summarize this information. Each
sub-section is keyed by the query-id and currently only contains the partial
metric and the unknown_docs section for each query.
This commit is contained in:
Christoph Büscher 2016-09-22 12:17:10 +02:00 committed by GitHub
parent b8f5374fb4
commit 29402a28e0
23 changed files with 571 additions and 153 deletions

View File

@ -30,8 +30,8 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -139,13 +139,13 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
} }
@Override @Override
public EvalQueryQuality evaluate(SearchHit[] hits, List<RatedDocument> ratedDocs) { public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
Map<RatedDocumentKey, RatedDocument> ratedDocsByKey = new HashMap<>(); Map<RatedDocumentKey, RatedDocument> ratedDocsByKey = new HashMap<>();
for (RatedDocument doc : ratedDocs) { for (RatedDocument doc : ratedDocs) {
ratedDocsByKey.put(doc.getKey(), doc); ratedDocsByKey.put(doc.getKey(), doc);
} }
Collection<RatedDocumentKey> unknownDocIds = new ArrayList<>(); List<RatedDocumentKey> unknownDocIds = new ArrayList<>();
List<Integer> ratings = new ArrayList<>(); List<Integer> ratings = new ArrayList<>();
for (int i = 0; (i < position && i < hits.length); i++) { for (int i = 0; (i < position && i < hits.length); i++) {
RatedDocumentKey id = new RatedDocumentKey(hits[i].getIndex(), hits[i].getType(), hits[i].getId()); RatedDocumentKey id = new RatedDocumentKey(hits[i].getIndex(), hits[i].getType(), hits[i].getId());
@ -156,24 +156,29 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
unknownDocIds.add(id); unknownDocIds.add(id);
if (unknownDocRating != null) { if (unknownDocRating != null) {
ratings.add(unknownDocRating); ratings.add(unknownDocRating);
} else {
// we add null here so that the later computation knows this position had no rating
ratings.add(null);
} }
} }
} }
double dcg = computeDCG(ratings); double dcg = computeDCG(ratings);
if (normalize) { if (normalize) {
Collections.sort(ratings, Collections.reverseOrder()); Collections.sort(ratings, Comparator.nullsLast(Collections.reverseOrder()));
double idcg = computeDCG(ratings); double idcg = computeDCG(ratings);
dcg = dcg / idcg; dcg = dcg / idcg;
} }
return new EvalQueryQuality(dcg, unknownDocIds); return new EvalQueryQuality(taskId, dcg, unknownDocIds);
} }
private static double computeDCG(List<Integer> ratings) { private static double computeDCG(List<Integer> ratings) {
int rank = 1; int rank = 1;
double dcg = 0; double dcg = 0;
for (int rating : ratings) { for (Integer rating : ratings) {
if (rating != null) {
dcg += (Math.pow(2, rating) - 1) / ((Math.log(rank + 1) / LOG2)); dcg += (Math.pow(2, rating) - 1) / ((Math.log(rank + 1) / LOG2));
}
rank++; rank++;
} }
return dcg; return dcg;
@ -227,4 +232,6 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
public final int hashCode() { public final int hashCode() {
return Objects.hash(position, normalize, unknownDocRating); return Objects.hash(position, normalize, unknownDocRating);
} }
// TODO maybe also add debugging breakdown here
} }

View File

@ -19,29 +19,107 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import java.util.Collection; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;;
/** /**
* Returned for each search specification. Summarizes the measured quality * This class represents the partial information from running the ranking evaluation metric on one
* metric for this search request and adds the document ids found that were in * request alone. It contains all information necessary to render the response for this part of the
* the search result but not annotated in the original request. * overall evaluation.
*/ */
public class EvalQueryQuality { public class EvalQueryQuality implements ToXContent, Writeable {
/** documents seen as result for one request that were not annotated.*/
private List<RatedDocumentKey> unknownDocs;
private String id;
private double qualityLevel; private double qualityLevel;
private MetricDetails optionalMetricDetails;
private Collection<RatedDocumentKey> unknownDocs; public EvalQueryQuality(String id, double qualityLevel, List<RatedDocumentKey> unknownDocs) {
this.id = id;
public EvalQueryQuality (double qualityLevel, Collection<RatedDocumentKey> unknownDocs) {
this.qualityLevel = qualityLevel;
this.unknownDocs = unknownDocs; this.unknownDocs = unknownDocs;
this.qualityLevel = qualityLevel;
} }
public Collection<RatedDocumentKey> getUnknownDocs() { public EvalQueryQuality(StreamInput in) throws IOException {
return unknownDocs; this(in.readString(), in.readDouble(), in.readList(RatedDocumentKey::new));
this.optionalMetricDetails = in.readOptionalNamedWriteable(MetricDetails.class);
}
public String getId() {
return id;
} }
public double getQualityLevel() { public double getQualityLevel() {
return qualityLevel; return qualityLevel;
} }
public List<RatedDocumentKey> getUnknownDocs() {
return Collections.unmodifiableList(this.unknownDocs);
}
public void addMetricDetails(MetricDetails breakdown) {
this.optionalMetricDetails = breakdown;
}
public MetricDetails getMetricDetails() {
return this.optionalMetricDetails;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
out.writeDouble(qualityLevel);
out.writeVInt(unknownDocs.size());
for (RatedDocumentKey key : unknownDocs) {
key.writeTo(out);
}
out.writeOptionalNamedWriteable(this.optionalMetricDetails);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(id);
builder.field("quality_level", this.qualityLevel);
builder.startArray("unknown_docs");
for (RatedDocumentKey key : unknownDocs) {
key.toXContent(builder, params);
}
builder.endArray();
if (optionalMetricDetails != null) {
builder.startObject("metric_details");
optionalMetricDetails.toXContent(builder, params);
builder.endObject();
}
builder.endObject();
return builder;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
EvalQueryQuality other = (EvalQueryQuality) obj;
return Objects.equals(id, other.id) &&
Objects.equals(qualityLevel, other.qualityLevel) &&
Objects.equals(unknownDocs, other.unknownDocs) &&
Objects.equals(optionalMetricDetails, other.optionalMetricDetails);
}
@Override
public final int hashCode() {
return Objects.hash(id, qualityLevel, unknownDocs, optionalMetricDetails);
}
} }

View File

@ -0,0 +1,27 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContent;
public interface MetricDetails extends ToXContent, NamedWriteable {
}

View File

@ -120,7 +120,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
* @return precision at n for above {@link SearchResult} list. * @return precision at n for above {@link SearchResult} list.
**/ **/
@Override @Override
public EvalQueryQuality evaluate(SearchHit[] hits, List<RatedDocument> ratedDocs) { public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
Collection<RatedDocumentKey> relevantDocIds = new ArrayList<>(); Collection<RatedDocumentKey> relevantDocIds = new ArrayList<>();
Collection<RatedDocumentKey> irrelevantDocIds = new ArrayList<>(); Collection<RatedDocumentKey> irrelevantDocIds = new ArrayList<>();
@ -134,7 +134,7 @@ public class PrecisionAtN extends RankedListQualityMetric {
int good = 0; int good = 0;
int bad = 0; int bad = 0;
Collection<RatedDocumentKey> unknownDocIds = new ArrayList<>(); List<RatedDocumentKey> unknownDocIds = new ArrayList<>();
for (int i = 0; (i < n && i < hits.length); i++) { for (int i = 0; (i < n && i < hits.length); i++) {
RatedDocumentKey hitKey = new RatedDocumentKey(hits[i].getIndex(), hits[i].getType(), hits[i].getId()); RatedDocumentKey hitKey = new RatedDocumentKey(hits[i].getIndex(), hits[i].getType(), hits[i].getId());
if (relevantDocIds.contains(hitKey)) { if (relevantDocIds.contains(hitKey)) {
@ -146,7 +146,9 @@ public class PrecisionAtN extends RankedListQualityMetric {
} }
} }
double precision = (double) good / (good + bad); double precision = (double) good / (good + bad);
return new EvalQueryQuality(precision, unknownDocIds); EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision, unknownDocIds);
evalQueryQuality.addMetricDetails(new PrecisionAtN.Breakdown(good, good + bad));
return evalQueryQuality;
} }
// TODO add abstraction that also works for other metrics // TODO add abstraction that also works for other metrics
@ -199,4 +201,66 @@ public class PrecisionAtN extends RankedListQualityMetric {
public final int hashCode() { public final int hashCode() {
return Objects.hash(n); return Objects.hash(n);
} }
public static class Breakdown implements MetricDetails {
public static final String DOCS_RETRIEVED_FIELD = "docs_retrieved";
public static final String RELEVANT_DOCS_RETRIEVED_FIELD = "relevant_docs_retrieved";
private int relevantRetrieved;
private int retrieved;
public Breakdown(int relevantRetrieved, int retrieved) {
this.relevantRetrieved = relevantRetrieved;
this.retrieved = retrieved;
}
public Breakdown(StreamInput in) throws IOException {
this.relevantRetrieved = in.readVInt();
this.retrieved = in.readVInt();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD, relevantRetrieved);
builder.field(DOCS_RETRIEVED_FIELD, retrieved);
return builder;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(relevantRetrieved);
out.writeVInt(retrieved);
}
@Override
public String getWriteableName() {
return NAME;
}
public int getRelevantRetrieved() {
return relevantRetrieved;
}
public int getRetrieved() {
return retrieved;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PrecisionAtN.Breakdown other = (PrecisionAtN.Breakdown) obj;
return Objects.equals(relevantRetrieved, other.relevantRetrieved) &&
Objects.equals(retrieved, other.retrieved);
}
@Override
public final int hashCode() {
return Objects.hash(relevantRetrieved, retrieved);
}
}
} }

View File

@ -49,9 +49,15 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
*/ */
@Override @Override
public List<NamedWriteableRegistry.Entry> getNamedWriteables() { public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> metrics = new ArrayList<>(); List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
metrics.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtN.NAME, PrecisionAtN::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtN.NAME, PrecisionAtN::new));
metrics.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new));
return metrics; namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGainAt.NAME,
DiscountedCumulativeGainAt::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtN.NAME,
PrecisionAtN.Breakdown::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, ReciprocalRank.NAME,
ReciprocalRank.Breakdown::new));
return namedWriteables;
} }
} }

View File

@ -26,8 +26,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.Collections;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -47,41 +46,37 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
/**Average precision observed when issuing query intents with this specification.*/ /**Average precision observed when issuing query intents with this specification.*/
private double qualityLevel; private double qualityLevel;
/**Mapping from intent id to all documents seen for this intent that were not annotated.*/ /**Mapping from intent id to all documents seen for this intent that were not annotated.*/
private Map<String, Collection<RatedDocumentKey>> unknownDocs; private Map<String, EvalQueryQuality> details;
public RankEvalResponse() { public RankEvalResponse() {
} }
public RankEvalResponse(double qualityLevel, Map<String, Collection<RatedDocumentKey>> unknownDocs) { public RankEvalResponse(double qualityLevel, Map<String, EvalQueryQuality> partialResults) {
this.qualityLevel = qualityLevel; this.qualityLevel = qualityLevel;
this.unknownDocs = unknownDocs; this.details = partialResults;
} }
public double getQualityLevel() { public double getQualityLevel() {
return qualityLevel; return qualityLevel;
} }
public Map<String, Collection<RatedDocumentKey>> getUnknownDocs() { public Map<String, EvalQueryQuality> getPartialResults() {
return unknownDocs; return Collections.unmodifiableMap(details);
} }
@Override @Override
public String toString() { public String toString() {
return "RankEvalResponse, quality: " + qualityLevel + ", unknown docs: " + unknownDocs; return "RankEvalResponse, quality: " + qualityLevel + ", partial results: " + details;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out); super.writeTo(out);
out.writeDouble(qualityLevel); out.writeDouble(qualityLevel);
out.writeVInt(unknownDocs.size()); out.writeVInt(details.size());
for (String queryId : unknownDocs.keySet()) { for (String queryId : details.keySet()) {
out.writeString(queryId); out.writeString(queryId);
Collection<RatedDocumentKey> collection = unknownDocs.get(queryId); details.get(queryId).writeTo(out);
out.writeVInt(collection.size());
for (RatedDocumentKey key : collection) {
key.writeTo(out);
}
} }
} }
@ -89,16 +84,12 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
public void readFrom(StreamInput in) throws IOException { public void readFrom(StreamInput in) throws IOException {
super.readFrom(in); super.readFrom(in);
this.qualityLevel = in.readDouble(); this.qualityLevel = in.readDouble();
int unknownDocumentSets = in.readVInt(); int partialResultSize = in.readVInt();
this.unknownDocs = new HashMap<>(unknownDocumentSets); this.details = new HashMap<>(partialResultSize);
for (int i = 0; i < unknownDocumentSets; i++) { for (int i = 0; i < partialResultSize; i++) {
String queryId = in.readString(); String queryId = in.readString();
int numberUnknownDocs = in.readVInt(); EvalQueryQuality partial = new EvalQueryQuality(in);
Collection<RatedDocumentKey> collection = new ArrayList<>(numberUnknownDocs); this.details.put(queryId, partial);
for (int d = 0; d < numberUnknownDocs; d++) {
collection.add(new RatedDocumentKey(in));
}
this.unknownDocs.put(queryId, collection);
} }
} }
@ -106,18 +97,9 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject("rank_eval"); builder.startObject("rank_eval");
builder.field("quality_level", qualityLevel); builder.field("quality_level", qualityLevel);
builder.startObject("unknown_docs"); builder.startObject("details");
for (String key : unknownDocs.keySet()) { for (String key : details.keySet()) {
Collection<RatedDocumentKey> keys = unknownDocs.get(key); details.get(key).toXContent(builder, params);
builder.startArray(key);
for (RatedDocumentKey docKey : keys) {
builder.startObject();
builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), docKey.getIndex());
builder.field(RatedDocument.TYPE_FIELD.getPreferredName(), docKey.getType());
builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), docKey.getDocID());
builder.endObject();
}
builder.endArray();
} }
builder.endObject(); builder.endObject();
builder.endObject(); builder.endObject();
@ -134,11 +116,11 @@ public class RankEvalResponse extends ActionResponse implements ToXContent {
} }
RankEvalResponse other = (RankEvalResponse) obj; RankEvalResponse other = (RankEvalResponse) obj;
return Objects.equals(qualityLevel, other.qualityLevel) && return Objects.equals(qualityLevel, other.qualityLevel) &&
Objects.equals(unknownDocs, other.unknownDocs); Objects.equals(details, other.details);
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(getClass(), qualityLevel, unknownDocs); return Objects.hash(qualityLevel, details);
} }
} }

View File

@ -44,10 +44,12 @@ public abstract class RankedListQualityMetric extends ToXContentToBytes implemen
* Returns a single metric representing the ranking quality of a set of returned documents * Returns a single metric representing the ranking quality of a set of returned documents
* wrt. to a set of document Ids labeled as relevant for this search. * wrt. to a set of document Ids labeled as relevant for this search.
* *
* @param taskId the id of the query for which the ranking is currently evaluated
* @param hits the result hits as returned by some search * @param hits the result hits as returned by some search
* @param ratedDocs the documents that were ranked by human annotators for this query case
* @return some metric representing the quality of the result hit list wrt. to relevant doc ids. * @return some metric representing the quality of the result hit list wrt. to relevant doc ids.
* */ * */
public abstract EvalQueryQuality evaluate(SearchHit[] hits, List<RatedDocument> ratedDocs); public abstract EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
public static RankedListQualityMetric fromXContent(XContentParser parser, ParseFieldMatcherSupplier context) throws IOException { public static RankedListQualityMetric fromXContent(XContentParser parser, ParseFieldMatcherSupplier context) throws IOException {
RankedListQualityMetric rc; RankedListQualityMetric rc;

View File

@ -19,14 +19,16 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
public class RatedDocumentKey implements Writeable { public class RatedDocumentKey extends ToXContentToBytes implements Writeable {
private String docId; private String docId;
private String type; private String type;
@ -93,4 +95,14 @@ public class RatedDocumentKey implements Writeable {
public final int hashCode() { public final int hashCode() {
return Objects.hash(index, type, docId); return Objects.hash(index, type, docId);
} }
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), index);
builder.field(RatedDocument.TYPE_FIELD.getPreferredName(), type);
builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), docId);
builder.endObject();
return builder;
}
} }

View File

@ -30,7 +30,6 @@ import org.elasticsearch.search.SearchHit;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -114,7 +113,7 @@ public class ReciprocalRank extends RankedListQualityMetric {
* @return reciprocal Rank for above {@link SearchResult} list. * @return reciprocal Rank for above {@link SearchResult} list.
**/ **/
@Override @Override
public EvalQueryQuality evaluate(SearchHit[] hits, List<RatedDocument> ratedDocs) { public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
Set<RatedDocumentKey> relevantDocIds = new HashSet<>(); Set<RatedDocumentKey> relevantDocIds = new HashSet<>();
Set<RatedDocumentKey> irrelevantDocIds = new HashSet<>(); Set<RatedDocumentKey> irrelevantDocIds = new HashSet<>();
for (RatedDocument doc : ratedDocs) { for (RatedDocument doc : ratedDocs) {
@ -125,7 +124,7 @@ public class ReciprocalRank extends RankedListQualityMetric {
} }
} }
Collection<RatedDocumentKey> unknownDocIds = new ArrayList<>(); List<RatedDocumentKey> unknownDocIds = new ArrayList<>();
int firstRelevant = -1; int firstRelevant = -1;
boolean found = false; boolean found = false;
for (int i = 0; i < hits.length; i++) { for (int i = 0; i < hits.length; i++) {
@ -141,7 +140,9 @@ public class ReciprocalRank extends RankedListQualityMetric {
} }
double reciprocalRank = (firstRelevant == -1) ? 0 : 1.0d / firstRelevant; double reciprocalRank = (firstRelevant == -1) ? 0 : 1.0d / firstRelevant;
return new EvalQueryQuality(reciprocalRank, unknownDocIds); EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, reciprocalRank, unknownDocIds);
evalQueryQuality.addMetricDetails(new Breakdown(firstRelevant));
return evalQueryQuality;
} }
@Override @Override
@ -189,4 +190,53 @@ public class ReciprocalRank extends RankedListQualityMetric {
public final int hashCode() { public final int hashCode() {
return Objects.hash(maxAcceptableRank); return Objects.hash(maxAcceptableRank);
} }
public static class Breakdown implements MetricDetails {
private int firstRelevantRank;
public Breakdown(int firstRelevantRank) {
this.firstRelevantRank = firstRelevantRank;
}
public Breakdown(StreamInput in) throws IOException {
this.firstRelevantRank = in.readVInt();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("first_relevant", firstRelevantRank);
return builder;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(firstRelevantRank);
}
@Override
public String getWriteableName() {
return NAME;
}
public int getFirstRelevantRank() {
return firstRelevantRank;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
ReciprocalRank.Breakdown other = (ReciprocalRank.Breakdown) obj;
return Objects.equals(firstRelevantRank, other.firstRelevantRank);
}
@Override
public final int hashCode() {
return Objects.hash(firstRelevantRank);
}
}
} }

View File

@ -88,7 +88,6 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
private RatedRequest specification; private RatedRequest specification;
private Map<String, EvalQueryQuality> partialResults; private Map<String, EvalQueryQuality> partialResults;
private RankEvalSpec task; private RankEvalSpec task;
private Map<String, Collection<RatedDocumentKey>> unknownDocs;
private AtomicInteger responseCounter; private AtomicInteger responseCounter;
public RankEvalActionListener(ActionListener<RankEvalResponse> listener, RankEvalSpec task, RatedRequest specification, public RankEvalActionListener(ActionListener<RankEvalResponse> listener, RankEvalSpec task, RatedRequest specification,
@ -98,21 +97,20 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
this.task = task; this.task = task;
this.specification = specification; this.specification = specification;
this.partialResults = partialResults; this.partialResults = partialResults;
this.unknownDocs = unknownDocs;
this.responseCounter = responseCounter; this.responseCounter = responseCounter;
} }
@Override @Override
public void onResponse(SearchResponse searchResponse) { public void onResponse(SearchResponse searchResponse) {
SearchHits hits = searchResponse.getHits(); SearchHits hits = searchResponse.getHits();
EvalQueryQuality queryQuality = task.getEvaluator().evaluate(hits.getHits(), specification.getRatedDocs()); EvalQueryQuality queryQuality = task.getEvaluator().evaluate(specification.getSpecId(), hits.getHits(),
specification.getRatedDocs());
partialResults.put(specification.getSpecId(), queryQuality); partialResults.put(specification.getSpecId(), queryQuality);
unknownDocs.put(specification.getSpecId(), queryQuality.getUnknownDocs());
if (responseCounter.decrementAndGet() < 1) { if (responseCounter.decrementAndGet() < 1) {
// TODO add other statistics like micro/macro avg? // TODO add other statistics like micro/macro avg?
listener.onResponse( listener.onResponse(
new RankEvalResponse(task.getEvaluator().combine(partialResults.values()), unknownDocs)); new RankEvalResponse(task.getEvaluator().combine(partialResults.values()), partialResults));
} }
} }

View File

@ -60,7 +60,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0))); hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0)));
} }
DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6); DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6);
assertEquals(13.84826362927298, dcg.evaluate(hits, rated).getQualityLevel(), 0.00001); assertEquals(13.84826362927298, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
/** /**
* Check with normalization: to get the maximal possible dcg, sort documents by relevance in descending order * Check with normalization: to get the maximal possible dcg, sort documents by relevance in descending order
@ -77,7 +77,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
* idcg = 14.595390756454922 (sum of last column) * idcg = 14.595390756454922 (sum of last column)
*/ */
dcg.setNormalize(true); dcg.setNormalize(true);
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate(hits, rated).getQualityLevel(), 0.00001); assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
} }
/** /**
@ -89,26 +89,45 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721 * 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5 * 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | n/a | n/a | n/a | n/a * 4 | n/a | n/a | n/a | n/a
* 5 | n/a | n/a | n/a | n/a * 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
* 6 | n/a | n/a | n/a | n/a * 6 | n/a | n/a | n/a | n/a
* *
* dcg = 13.84826362927298 (sum of last column) * dcg = 12.779642067948913 (sum of last column)
*/ */
public void testDCGAtSixMissingRatings() throws IOException, InterruptedException, ExecutionException { public void testDCGAtSixMissingRatings() throws IOException, InterruptedException, ExecutionException {
List<RatedDocument> rated = new ArrayList<>(); List<RatedDocument> rated = new ArrayList<>();
int[] relevanceRatings = new int[] { 3, 2, 3}; Integer[] relevanceRatings = new Integer[] { 3, 2, 3, null, 1};
InternalSearchHit[] hits = new InternalSearchHit[6]; InternalSearchHit[] hits = new InternalSearchHit[6];
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
if (i < relevanceRatings.length) { if (i < relevanceRatings.length) {
if (relevanceRatings[i] != null) {
rated.add(new RatedDocument("index", "type", Integer.toString(i), relevanceRatings[i])); rated.add(new RatedDocument("index", "type", Integer.toString(i), relevanceRatings[i]));
} }
}
hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap()); hits[i] = new InternalSearchHit(i, Integer.toString(i), new Text("type"), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0))); hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0)));
} }
DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6); DiscountedCumulativeGainAt dcg = new DiscountedCumulativeGainAt(6);
EvalQueryQuality result = dcg.evaluate(hits, rated); EvalQueryQuality result = dcg.evaluate("id", hits, rated);
assertEquals(12.392789260714371, result.getQualityLevel(), 0.00001); assertEquals(12.779642067948913, result.getQualityLevel(), 0.00001);
assertEquals(3, result.getUnknownDocs().size()); assertEquals(2, result.getUnknownDocs().size());
/**
* Check with normalization: to get the maximal possible dcg, sort documents by relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5
* 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339
* 5 | n.a | n.a | n.a.  | n.a.
* 6 | n.a | n.a | n.a  | n.a
*
* idcg = 13.347184833073591 (sum of last column)
*/
dcg.setNormalize(true);
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
} }
public void testParseFromXContent() throws IOException { public void testParseFromXContent() throws IOException {
@ -131,7 +150,7 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
} }
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
DiscountedCumulativeGainAt testItem = createTestItem(); DiscountedCumulativeGainAt testItem = createTestItem();
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
itemParser.nextToken(); itemParser.nextToken();
itemParser.nextToken(); itemParser.nextToken();
DiscountedCumulativeGainAt parsedItem = DiscountedCumulativeGainAt.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); DiscountedCumulativeGainAt parsedItem = DiscountedCumulativeGainAt.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);

View File

@ -0,0 +1,105 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class EvalQueryQualityTests extends ESTestCase {
private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(new RankEvalPlugin().getNamedWriteables());
public static EvalQueryQuality randomEvalQueryQuality() {
List<RatedDocumentKey> unknownDocs = new ArrayList<>();
int numberOfUnknownDocs = randomInt(5);
for (int i = 0; i < numberOfUnknownDocs; i++) {
unknownDocs.add(RatedDocumentKeyTests.createRandomRatedDocumentKey());
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAsciiOfLength(10), randomDoubleBetween(0.0, 1.0, true), unknownDocs);
if (randomBoolean()) {
// TODO randomize this
evalQueryQuality.addMetricDetails(new PrecisionAtN.Breakdown(1, 5));
}
return evalQueryQuality;
}
private static EvalQueryQuality copy(EvalQueryQuality original) throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput()) {
original.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWritableRegistry)) {
return new EvalQueryQuality(in);
}
}
}
public void testSerialization() throws IOException {
EvalQueryQuality original = randomEvalQueryQuality();
EvalQueryQuality deserialized = copy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
EvalQueryQuality testItem = randomEvalQueryQuality();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
copy(testItem));
}
private static EvalQueryQuality mutateTestItem(EvalQueryQuality original) {
String id = original.getId();
double qualityLevel = original.getQualityLevel();
List<RatedDocumentKey> unknownDocs = original.getUnknownDocs();
MetricDetails breakdown = original.getMetricDetails();
switch (randomIntBetween(0, 3)) {
case 0:
id = id + "_";
break;
case 1:
qualityLevel = qualityLevel + 0.1;
break;
case 2:
unknownDocs = new ArrayList<>(unknownDocs);
unknownDocs.add(RatedDocumentKeyTests.createRandomRatedDocumentKey());
break;
case 3:
if (breakdown == null) {
breakdown = new PrecisionAtN.Breakdown(1, 5);
} else {
breakdown = null;
}
break;
default:
throw new IllegalStateException("The test should only allow three parameters mutated");
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(id, qualityLevel, unknownDocs);
evalQueryQuality.addMetricDetails(breakdown);
return evalQueryQuality;
}
}

View File

@ -46,7 +46,10 @@ public class PrecisionAtNTests extends ESTestCase {
InternalSearchHit[] hits = new InternalSearchHit[1]; InternalSearchHit[] hits = new InternalSearchHit[1];
hits[0] = new InternalSearchHit(0, "0", new Text("testtype"), Collections.emptyMap()); hits[0] = new InternalSearchHit(0, "0", new Text("testtype"), Collections.emptyMap());
hits[0].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0)); hits[0].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0));
assertEquals(1, (new PrecisionAtN(5)).evaluate(hits, rated).getQualityLevel(), 0.00001); EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", hits, rated);
assertEquals(1, evaluated.getQualityLevel(), 0.00001);
assertEquals(1, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(1, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
} }
public void testPrecisionAtFiveIgnoreOneResult() throws IOException, InterruptedException, ExecutionException { public void testPrecisionAtFiveIgnoreOneResult() throws IOException, InterruptedException, ExecutionException {
@ -61,7 +64,10 @@ public class PrecisionAtNTests extends ESTestCase {
hits[i] = new InternalSearchHit(i, i+"", new Text("testtype"), Collections.emptyMap()); hits[i] = new InternalSearchHit(i, i+"", new Text("testtype"), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0)); hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0));
} }
assertEquals((double) 4 / 5, (new PrecisionAtN(5)).evaluate(hits, rated).getQualityLevel(), 0.00001); EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", hits, rated);
assertEquals((double) 4 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(4, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
} }
/** /**
@ -82,7 +88,10 @@ public class PrecisionAtNTests extends ESTestCase {
} }
PrecisionAtN precisionAtN = new PrecisionAtN(5); PrecisionAtN precisionAtN = new PrecisionAtN(5);
precisionAtN.setRelevantRatingThreshhold(2); precisionAtN.setRelevantRatingThreshhold(2);
assertEquals((double) 3 / 5, precisionAtN.evaluate(hits, rated).getQualityLevel(), 0.00001); EvalQueryQuality evaluated = precisionAtN.evaluate("id", hits, rated);
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(3, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(5, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
} }
public void testPrecisionAtFiveCorrectIndex() throws IOException, InterruptedException, ExecutionException { public void testPrecisionAtFiveCorrectIndex() throws IOException, InterruptedException, ExecutionException {
@ -97,7 +106,10 @@ public class PrecisionAtNTests extends ESTestCase {
hits[i] = new InternalSearchHit(i, i+"", new Text("testtype"), Collections.emptyMap()); hits[i] = new InternalSearchHit(i, i+"", new Text("testtype"), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0)); hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0));
} }
assertEquals((double) 2 / 3, (new PrecisionAtN(5)).evaluate(hits, rated).getQualityLevel(), 0.00001); EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", hits, rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
} }
public void testPrecisionAtFiveCorrectType() throws IOException, InterruptedException, ExecutionException { public void testPrecisionAtFiveCorrectType() throws IOException, InterruptedException, ExecutionException {
@ -112,7 +124,10 @@ public class PrecisionAtNTests extends ESTestCase {
hits[i] = new InternalSearchHit(i, i+"", new Text("testtype"), Collections.emptyMap()); hits[i] = new InternalSearchHit(i, i+"", new Text("testtype"), Collections.emptyMap());
hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0)); hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0));
} }
assertEquals((double) 2 / 3, (new PrecisionAtN(5)).evaluate(hits, rated).getQualityLevel(), 0.00001); EvalQueryQuality evaluated = (new PrecisionAtN(5)).evaluate("id", hits, rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
assertEquals(3, ((PrecisionAtN.Breakdown) evaluated.getMetricDetails()).getRetrieved());
} }
public void testParseFromXContent() throws IOException { public void testParseFromXContent() throws IOException {
@ -129,9 +144,9 @@ public class PrecisionAtNTests extends ESTestCase {
public void testCombine() { public void testCombine() {
PrecisionAtN metric = new PrecisionAtN(); PrecisionAtN metric = new PrecisionAtN();
Vector<EvalQueryQuality> partialResults = new Vector<>(3); Vector<EvalQueryQuality> partialResults = new Vector<>(3);
partialResults.add(new EvalQueryQuality(0.1, emptyList())); partialResults.add(new EvalQueryQuality("a", 0.1, emptyList()));
partialResults.add(new EvalQueryQuality(0.2, emptyList())); partialResults.add(new EvalQueryQuality("b", 0.2, emptyList()));
partialResults.add(new EvalQueryQuality(0.6, emptyList())); partialResults.add(new EvalQueryQuality("c", 0.6, emptyList()));
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE); assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
} }
@ -142,7 +157,7 @@ public class PrecisionAtNTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException { public void testXContentRoundtrip() throws IOException {
PrecisionAtN testItem = createTestItem(); PrecisionAtN testItem = createTestItem();
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
itemParser.nextToken(); itemParser.nextToken();
itemParser.nextToken(); itemParser.nextToken();
PrecisionAtN parsedItem = PrecisionAtN.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); PrecisionAtN parsedItem = PrecisionAtN.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);

View File

@ -83,14 +83,14 @@ public class RankEvalRequestTests extends ESIntegTestCase {
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE); assertEquals(1.0, response.getQualityLevel(), Double.MIN_VALUE);
Set<Entry<String, Collection<RatedDocumentKey>>> entrySet = response.getUnknownDocs().entrySet(); Set<Entry<String, EvalQueryQuality>> entrySet = response.getPartialResults().entrySet();
assertEquals(2, entrySet.size()); assertEquals(2, entrySet.size());
for (Entry<String, Collection<RatedDocumentKey>> entry : entrySet) { for (Entry<String, EvalQueryQuality> entry : entrySet) {
if (entry.getKey() == "amsterdam_query") { if (entry.getKey() == "amsterdam_query") {
assertEquals(2, entry.getValue().size()); assertEquals(2, entry.getValue().getUnknownDocs().size());
} }
if (entry.getKey() == "berlin_query") { if (entry.getKey() == "berlin_query") {
assertEquals(5, entry.getValue().size()); assertEquals(5, entry.getValue().getUnknownDocs().size());
} }
} }
} }

View File

@ -29,7 +29,6 @@ import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -37,17 +36,18 @@ import java.util.Map;
public class RankEvalResponseTests extends ESTestCase { public class RankEvalResponseTests extends ESTestCase {
private static RankEvalResponse createRandomResponse() { private static RankEvalResponse createRandomResponse() {
Map<String, Collection<RatedDocumentKey>> unknownDocs = new HashMap<>(); int numberOfRequests = randomIntBetween(0, 5);
int numberOfSets = randomIntBetween(0, 5); Map<String, EvalQueryQuality> partials = new HashMap<>(numberOfRequests);
for (int i = 0; i < numberOfSets; i++) { for (int i = 0; i < numberOfRequests; i++) {
List<RatedDocumentKey> ids = new ArrayList<>(); String id = randomAsciiOfLengthBetween(3, 10);
int numberOfUnknownDocs = randomIntBetween(0, 5); int numberOfUnknownDocs = randomIntBetween(0, 5);
List<RatedDocumentKey> unknownDocs = new ArrayList<>(numberOfUnknownDocs);
for (int d = 0; d < numberOfUnknownDocs; d++) { for (int d = 0; d < numberOfUnknownDocs; d++) {
ids.add(new RatedDocumentKey(randomAsciiOfLength(5), randomAsciiOfLength(5), randomAsciiOfLength(5))); unknownDocs.add(RatedDocumentKeyTests.createRandomRatedDocumentKey());
} }
unknownDocs.put(randomAsciiOfLength(5), ids); partials.put(id, new EvalQueryQuality(id, randomDoubleBetween(0.0, 1.0, true), unknownDocs));
} }
return new RankEvalResponse(randomDouble(), unknownDocs ); return new RankEvalResponse(randomDouble(), partials);
} }
public void testSerialization() throws IOException { public void testSerialization() throws IOException {

View File

@ -89,7 +89,7 @@ public class RankEvalSpecTests extends ESTestCase {
RankEvalSpec testItem = new RankEvalSpec(specs, metric); RankEvalSpec testItem = new RankEvalSpec(specs, metric);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT); QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext, RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,

View File

@ -30,7 +30,14 @@ import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;
public class XContentTestHelper { import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
public class RankEvalTestHelper {
public static XContentParser roundtrip(ToXContentToBytes testItem) throws IOException { public static XContentParser roundtrip(ToXContentToBytes testItem) throws IOException {
XContentBuilder builder = XContentFactory.contentBuilder(ESTestCase.randomFrom(XContentType.values())); XContentBuilder builder = XContentFactory.contentBuilder(ESTestCase.randomFrom(XContentType.values()));
@ -42,4 +49,20 @@ public class XContentTestHelper {
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes()); XContentParser itemParser = XContentHelper.createParser(shuffled.bytes());
return itemParser; return itemParser;
} }
public static void testHashCodeAndEquals(Object testItem, Object mutation, Object secondCopy) {
assertFalse("testItem is equal to null", testItem.equals(null));
assertFalse("testItem is equal to incompatible type", testItem.equals(""));
assertTrue("testItem is not equal to self", testItem.equals(testItem));
assertThat("same testItem's hashcode returns different values if called multiple times", testItem.hashCode(),
equalTo(testItem.hashCode()));
assertThat("different testItem should not be equal", mutation, not(equalTo(testItem)));
assertNotSame("testItem copy is not same as original", testItem, secondCopy);
assertTrue("testItem is not equal to its copy", testItem.equals(secondCopy));
assertTrue("equals is not symmetric", secondCopy.equals(testItem));
assertThat("testItem copy's hashcode is different from original hashcode", secondCopy.hashCode(),
equalTo(testItem.hashCode()));
}
} }

View File

@ -23,45 +23,43 @@ import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class RatedDocumentKeyTests extends ESTestCase { public class RatedDocumentKeyTests extends ESTestCase {
public void testEqualsAndHash() throws IOException { static RatedDocumentKey createRandomRatedDocumentKey() {
String index = randomAsciiOfLengthBetween(0, 10); String index = randomAsciiOfLengthBetween(0, 10);
String type = randomAsciiOfLengthBetween(0, 10); String type = randomAsciiOfLengthBetween(0, 10);
String docId = randomAsciiOfLengthBetween(0, 10); String docId = randomAsciiOfLengthBetween(0, 10);
return new RatedDocumentKey(index, type, docId);
}
RatedDocumentKey testItem = new RatedDocumentKey(index, type, docId); public RatedDocumentKey createRandomTestItem() {
return createRandomRatedDocumentKey();
}
assertFalse("key is equal to null", testItem.equals(null)); public RatedDocumentKey mutateTestItem(RatedDocumentKey original) {
assertFalse("key is equal to incompatible type", testItem.equals("")); String index = original.getIndex();
assertTrue("key is not equal to self", testItem.equals(testItem)); String type = original.getType();
assertThat("same key's hashcode returns different values if called multiple times", testItem.hashCode(), String docId = original.getDocID();
equalTo(testItem.hashCode()));
RatedDocumentKey mutation;
switch (randomIntBetween(0, 2)) { switch (randomIntBetween(0, 2)) {
case 0: case 0:
mutation = new RatedDocumentKey(testItem.getIndex() + "_foo", testItem.getType(), testItem.getDocID()); index = index + "_";
break; break;
case 1: case 1:
mutation = new RatedDocumentKey(testItem.getIndex(), testItem.getType() + "_foo", testItem.getDocID()); type = type + "_";
break; break;
case 2: case 2:
mutation = new RatedDocumentKey(testItem.getIndex(), testItem.getType(), testItem.getDocID() + "_foo"); docId = docId + "_";
break; break;
default: default:
throw new IllegalStateException("The test should only allow three parameters mutated"); throw new IllegalStateException("The test should only allow three parameters mutated");
} }
return new RatedDocumentKey(index, type, docId);
}
assertThat("different keys should not be equal", mutation, not(equalTo(testItem))); public void testEqualsAndHash() throws IOException {
RatedDocumentKey testItem = createRandomRatedDocumentKey();
RatedDocumentKey secondEqualKey = new RatedDocumentKey(index, type, docId); RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
assertTrue("key is not equal to its copy", testItem.equals(secondEqualKey)); new RatedDocumentKey(testItem.getIndex(), testItem.getType(), testItem.getDocID()));
assertTrue("equals is not symmetric", secondEqualKey.equals(testItem));
assertThat("key copy's hashcode is different from original hashcode", secondEqualKey.hashCode(),
equalTo(testItem.hashCode()));
} }
} }

View File

@ -27,7 +27,7 @@ import java.io.IOException;
public class RatedDocumentTests extends ESTestCase { public class RatedDocumentTests extends ESTestCase {
public static RatedDocument createTestItem() { public static RatedDocument createRatedDocument() {
String index = randomAsciiOfLength(10); String index = randomAsciiOfLength(10);
String type = randomAsciiOfLength(10); String type = randomAsciiOfLength(10);
String docId = randomAsciiOfLength(10); String docId = randomAsciiOfLength(10);
@ -37,8 +37,8 @@ public class RatedDocumentTests extends ESTestCase {
} }
public void testXContentParsing() throws IOException { public void testXContentParsing() throws IOException {
RatedDocument testItem = createTestItem(); RatedDocument testItem = createRatedDocument();
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
RatedDocument parsedItem = RatedDocument.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); RatedDocument parsedItem = RatedDocument.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem); assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);

View File

@ -76,7 +76,7 @@ public class RatedRequestsTests extends ESTestCase {
List<RatedDocument> ratedDocs = new ArrayList<>(); List<RatedDocument> ratedDocs = new ArrayList<>();
int size = randomIntBetween(0, 2); int size = randomIntBetween(0, 2);
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
ratedDocs.add(RatedDocumentTests.createTestItem()); ratedDocs.add(RatedDocumentTests.createRatedDocument());
} }
return new RatedRequest(specId, testRequest, indices, types, ratedDocs); return new RatedRequest(specId, testRequest, indices, types, ratedDocs);
@ -96,7 +96,7 @@ public class RatedRequestsTests extends ESTestCase {
} }
RatedRequest testItem = createTestItem(indices, types); RatedRequest testItem = createTestItem(indices, types);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
itemParser.nextToken(); itemParser.nextToken();
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT); QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);

View File

@ -63,17 +63,19 @@ public class ReciprocalRankTests extends ESTestCase {
} }
int rankAtFirstRelevant = relevantAt + 1; int rankAtFirstRelevant = relevantAt + 1;
EvalQueryQuality evaluation = reciprocalRank.evaluate(hits, ratedDocs); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, ratedDocs);
if (rankAtFirstRelevant <= maxRank) { if (rankAtFirstRelevant <= maxRank) {
assertEquals(1.0 / rankAtFirstRelevant, evaluation.getQualityLevel(), Double.MIN_VALUE); assertEquals(1.0 / rankAtFirstRelevant, evaluation.getQualityLevel(), Double.MIN_VALUE);
assertEquals(rankAtFirstRelevant, ((ReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
// check that if we lower maxRank by one, we don't find any result and get 0.0 quality level // check that if we lower maxRank by one, we don't find any result and get 0.0 quality level
reciprocalRank = new ReciprocalRank(rankAtFirstRelevant - 1); reciprocalRank = new ReciprocalRank(rankAtFirstRelevant - 1);
evaluation = reciprocalRank.evaluate(hits, ratedDocs); evaluation = reciprocalRank.evaluate("id", hits, ratedDocs);
assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE); assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE);
} else { } else {
assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE); assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE);
assertEquals(-1, ((ReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
} }
} }
@ -95,8 +97,9 @@ public class ReciprocalRankTests extends ESTestCase {
} }
} }
EvalQueryQuality evaluation = reciprocalRank.evaluate(hits, ratedDocs); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, ratedDocs);
assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE); assertEquals(1.0 / (relevantAt + 1), evaluation.getQualityLevel(), Double.MIN_VALUE);
assertEquals(relevantAt + 1, ((ReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
} }
/** /**
@ -119,15 +122,17 @@ public class ReciprocalRankTests extends ESTestCase {
ReciprocalRank reciprocalRank = new ReciprocalRank(); ReciprocalRank reciprocalRank = new ReciprocalRank();
reciprocalRank.setRelevantRatingThreshhold(2); reciprocalRank.setRelevantRatingThreshhold(2);
assertEquals((double) 1 / 3, reciprocalRank.evaluate(hits, rated).getQualityLevel(), 0.00001); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
assertEquals(3, ((ReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
} }
public void testCombine() { public void testCombine() {
ReciprocalRank reciprocalRank = new ReciprocalRank(); ReciprocalRank reciprocalRank = new ReciprocalRank();
Vector<EvalQueryQuality> partialResults = new Vector<>(3); Vector<EvalQueryQuality> partialResults = new Vector<>(3);
partialResults.add(new EvalQueryQuality(0.5, emptyList())); partialResults.add(new EvalQueryQuality("id1", 0.5, emptyList()));
partialResults.add(new EvalQueryQuality(1.0, emptyList())); partialResults.add(new EvalQueryQuality("id2", 1.0, emptyList()));
partialResults.add(new EvalQueryQuality(0.75, emptyList())); partialResults.add(new EvalQueryQuality("id3", 0.75, emptyList()));
assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE); assertEquals(0.75, reciprocalRank.combine(partialResults), Double.MIN_VALUE);
} }
@ -139,7 +144,7 @@ public class ReciprocalRankTests extends ESTestCase {
hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0)); hits[i].shard(new SearchShardTarget("testnode", new Index("test", "uuid"), 0));
} }
List<RatedDocument> ratedDocs = new ArrayList<>(); List<RatedDocument> ratedDocs = new ArrayList<>();
EvalQueryQuality evaluation = reciprocalRank.evaluate(hits, ratedDocs); EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, ratedDocs);
assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE); assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE);
} }
@ -147,7 +152,7 @@ public class ReciprocalRankTests extends ESTestCase {
int position = randomIntBetween(0, 1000); int position = randomIntBetween(0, 1000);
ReciprocalRank testItem = new ReciprocalRank(position); ReciprocalRank testItem = new ReciprocalRank(position);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem); XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
itemParser.nextToken(); itemParser.nextToken();
itemParser.nextToken(); itemParser.nextToken();
ReciprocalRank parsedItem = ReciprocalRank.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT); ReciprocalRank parsedItem = ReciprocalRank.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
@ -155,5 +160,4 @@ public class ReciprocalRankTests extends ESTestCase {
assertEquals(testItem, parsedItem); assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode()); assertEquals(testItem.hashCode(), parsedItem.hashCode());
} }
} }

View File

@ -60,8 +60,13 @@
} }
- match: {rank_eval.quality_level: 1} - match: {rank_eval.quality_level: 1}
- match: {rank_eval.unknown_docs.amsterdam_query: [ {"_index": "foo", "_type": "bar", "_id": "doc4"}]} - match: {rank_eval.details.amsterdam_query.quality_level: 1.0}
- match: {rank_eval.unknown_docs.berlin_query: [ {"_index": "foo", "_type": "bar", "_id": "doc4"}]} - match: {rank_eval.details.amsterdam_query.unknown_docs: [ {"_index": "foo", "_type": "bar", "_id": "doc4"}]}
- match: {rank_eval.details.amsterdam_query.metric_details: {"relevant_docs_retrieved": 2, "docs_retrieved": 2}}
- match: {rank_eval.details.berlin_query.quality_level: 1.0}
- match: {rank_eval.details.berlin_query.unknown_docs: [ {"_index": "foo", "_type": "bar", "_id": "doc4"}]}
- match: {rank_eval.details.berlin_query.metric_details: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}}
--- ---
"Reciprocal Rank": "Reciprocal Rank":
@ -125,6 +130,13 @@
# average is (1/3 + 1/2)/2 = 5/12 ~ 0.41666666666666663 # average is (1/3 + 1/2)/2 = 5/12 ~ 0.41666666666666663
- match: {rank_eval.quality_level: 0.41666666666666663} - match: {rank_eval.quality_level: 0.41666666666666663}
- match: {rank_eval.details.amsterdam_query.quality_level: 0.3333333333333333}
- match: {rank_eval.details.amsterdam_query.metric_details: {"first_relevant": 3}}
- match: {rank_eval.details.amsterdam_query.unknown_docs: [ {"_index": "foo", "_type": "bar", "_id": "doc2"},
{"_index": "foo", "_type": "bar", "_id": "doc3"} ]}
- match: {rank_eval.details.berlin_query.quality_level: 0.5}
- match: {rank_eval.details.berlin_query.metric_details: {"first_relevant": 2}}
- match: {rank_eval.details.berlin_query.unknown_docs: [ {"_index": "foo", "_type": "bar", "_id": "doc1"}]}
- do: - do:
rank_eval: rank_eval:
@ -145,6 +157,7 @@
], ],
"metric" : { "metric" : {
"reciprocal_rank": { "reciprocal_rank": {
# the following will make the first query have a quality value of 0.0
"max_acceptable_rank" : 2 "max_acceptable_rank" : 2
} }
} }
@ -152,3 +165,10 @@
# average is (0 + 1/2)/2 = 1/4 # average is (0 + 1/2)/2 = 1/4
- match: {rank_eval.quality_level: 0.25} - match: {rank_eval.quality_level: 0.25}
- match: {rank_eval.details.amsterdam_query.quality_level: 0}
- match: {rank_eval.details.amsterdam_query.metric_details: {"first_relevant": -1}}
- match: {rank_eval.details.amsterdam_query.unknown_docs: [ {"_index": "foo", "_type": "bar", "_id": "doc2"},
{"_index": "foo", "_type": "bar", "_id": "doc3"} ]}
- match: {rank_eval.details.berlin_query.quality_level: 0.5}
- match: {rank_eval.details.berlin_query.metric_details: {"first_relevant": 2}}
- match: {rank_eval.details.berlin_query.unknown_docs: [ {"_index": "foo", "_type": "bar", "_id": "doc1"}]}

View File

@ -64,6 +64,8 @@
} }
- match: {rank_eval.quality_level: 13.84826362927298} - match: {rank_eval.quality_level: 13.84826362927298}
- match: {rank_eval.details.dcg_query.quality_level: 13.84826362927298}
- match: {rank_eval.details.dcg_query.unknown_docs: [ ]}
# reverse the order in which the results are returned (less relevant docs first) # reverse the order in which the results are returned (less relevant docs first)
@ -87,6 +89,8 @@
} }
- match: {rank_eval.quality_level: 10.29967439154499} - match: {rank_eval.quality_level: 10.29967439154499}
- match: {rank_eval.details.dcg_query_reverse.quality_level: 10.29967439154499}
- match: {rank_eval.details.dcg_query_reverse.unknown_docs: [ ]}
# if we mix both, we should get the average # if we mix both, we should get the average
@ -121,3 +125,7 @@
} }
- match: {rank_eval.quality_level: 12.073969010408984} - match: {rank_eval.quality_level: 12.073969010408984}
- match: {rank_eval.details.dcg_query.quality_level: 13.84826362927298}
- match: {rank_eval.details.dcg_query.unknown_docs: [ ]}
- match: {rank_eval.details.dcg_query_reverse.quality_level: 10.29967439154499}
- match: {rank_eval.details.dcg_query_reverse.unknown_docs: [ ]}