Improving and cleaning up tests

Removing the unnecessary RankEvalTestHelper, making use of the common test infra
in ESTestCase, also hardening a few of the classes by making more fields final.
This commit is contained in:
Christoph Büscher 2017-11-14 19:26:32 +01:00 committed by Christoph Büscher
parent 5c65a59369
commit e278c1d17d
25 changed files with 450 additions and 636 deletions

View File

@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
@ -35,22 +35,25 @@ import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
public class DiscountedCumulativeGain implements RankedListQualityMetric {
public class DiscountedCumulativeGain implements EvaluationMetric {
/** If set to true, the dcg will be normalized (ndcg) */
private boolean normalize;
private final boolean normalize;
/**
* If set, this will be the rating for docs the user hasn't supplied an
* explicit rating for
*/
private Integer unknownDocRating;
private final Integer unknownDocRating;
public static final String NAME = "dcg";
private static final double LOG2 = Math.log(2.0);
/**
* Creates a metric with default settings: normalization is switched off
* (plain dcg) and no substitute rating is set for unrated documents.
*/
public DiscountedCumulativeGain() {
this(false, null);
}
/**
@ -82,13 +85,6 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
return NAME;
}
/**
* If set to true, the dcg will be normalized (ndcg)
*/
public void setNormalize(boolean normalize) {
this.normalize = normalize;
}
/**
* check whether this metric computes only dcg or "normalized" ndcg
*/
@ -96,13 +92,6 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
return this.normalize;
}
/**
* the rating for docs the user hasn't supplied an explicit rating for
*/
public void setUnknownDocRating(int unknownDocRating) {
this.unknownDocRating = unknownDocRating;
}
/**
* get the rating used for unrated documents
*/
@ -118,10 +107,10 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
for (RatedSearchHit hit : ratedHits) {
// unknownDocRating might be null, which means it will be unrated
// docs are ignored in the dcg calculation
// we still need to add them as a placeholder so the rank of the
// subsequent ratings is correct
// unknownDocRating might be null, in which case unrated docs are
// ignored in the dcg calculation
// we still need to add them as a placeholder so the rank of the subsequent
// ratings is correct
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
}
double dcg = computeDCG(ratingsInSearchHits);
@ -151,12 +140,15 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ObjectParser<DiscountedCumulativeGain, Void> PARSER = new ObjectParser<>(
"dcg_at", () -> new DiscountedCumulativeGain());
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
args -> {
Boolean normalized = (Boolean) args[0];
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
});
static {
PARSER.declareBoolean(DiscountedCumulativeGain::setNormalize, NORMALIZE_FIELD);
PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
}
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
@ -193,6 +185,4 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
public final int hashCode() {
return Objects.hash(normalize, unknownDocRating);
}
// TODO maybe also add debugging breakdown here
}

View File

@ -1,106 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
 * Identifies a rated document by the index it lives in and its document id.
 * <p>
 * Both parts must be non-null and non-empty. This is enforced by the public
 * constructor and by the package-private setters, so an existing key cannot be
 * mutated into an invalid state. The stream constructor reads the two strings
 * in the order {@link #writeTo(StreamOutput)} writes them (index, then doc id).
 */
public class DocumentKey implements Writeable, ToXContentObject {

    private String docId;
    private String index;

    /**
     * Replaces the index part of this key.
     *
     * @throws IllegalArgumentException if {@code index} is null or empty
     */
    void setIndex(String index) {
        this.index = requireIndex(index);
    }

    /**
     * Replaces the document id part of this key.
     *
     * @throws IllegalArgumentException if {@code docId} is null or empty
     */
    void setDocId(String docId) {
        this.docId = requireDocId(docId);
    }

    /**
     * @param index the index the document lives in, must not be null or empty
     * @param docId the id of the document, must not be null or empty
     * @throws IllegalArgumentException if either argument is null or empty
     */
    public DocumentKey(String index, String docId) {
        // validate the index first so the error reported for two bad arguments stays the same
        this.index = requireIndex(index);
        this.docId = requireDocId(docId);
    }

    /**
     * Reads a key from the stream: first the index, then the document id.
     */
    public DocumentKey(StreamInput in) throws IOException {
        this.index = in.readString();
        this.docId = in.readString();
    }

    public String getIndex() {
        return index;
    }

    public String getDocID() {
        return docId;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(index);
        out.writeString(docId);
    }

    @Override
    public final boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        DocumentKey other = (DocumentKey) obj;
        return Objects.equals(index, other.index) && Objects.equals(docId, other.docId);
    }

    @Override
    public final int hashCode() {
        return Objects.hash(index, docId);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), index);
        builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), docId);
        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    /** Returns {@code index} unchanged, or throws if it is null or empty. */
    private static String requireIndex(String index) {
        if (Strings.isNullOrEmpty(index)) {
            throw new IllegalArgumentException("Index must be set for each rated document");
        }
        return index;
    }

    /** Returns {@code docId} unchanged, or throws if it is null or empty. */
    private static String requireDocId(String docId) {
        if (Strings.isNullOrEmpty(docId)) {
            throw new IllegalArgumentException("DocId must be set for each rated document");
        }
        return docId;
    }
}

View File

@ -24,6 +24,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import java.io.IOException;
import java.util.ArrayList;
@ -91,8 +92,11 @@ public class EvalQueryQuality implements ToXContent, Writeable {
builder.startObject(id);
builder.field("quality_level", this.qualityLevel);
builder.startArray("unknown_docs");
for (DocumentKey key : RankedListQualityMetric.filterUnknownDocuments(hits)) {
key.toXContent(builder, params);
for (DocumentKey key : EvaluationMetric.filterUnknownDocuments(hits)) {
builder.startObject();
builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), key.getIndex());
builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), key.getDocId());
builder.endObject();
}
builder.endArray();
builder.startArray("hits");

View File

@ -24,7 +24,9 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import java.io.IOException;
import java.util.ArrayList;
@ -35,13 +37,14 @@ import java.util.Optional;
import java.util.stream.Collectors;
/**
* Classes implementing this interface provide a means to compute the quality of a result list returned by some search.
* Implementations of {@link EvaluationMetric} need to provide a way to compute the quality metric for
* a result list returned by some search ({@link SearchHits}) and a list of rated documents.
*/
public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
public interface EvaluationMetric extends ToXContent, NamedWriteable {
/**
* Returns a single metric representing the ranking quality of a set of returned
* documents wrt. to a set of document Ids labeled as relevant for this search.
* documents with respect to a set of document ids labeled as relevant for this search.
*
* @param taskId
* the id of the query for which the ranking is currently evaluated
@ -55,15 +58,15 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
*/
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
static RankedListQualityMetric fromXContent(XContentParser parser) throws IOException {
RankedListQualityMetric rc;
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
EvaluationMetric rc;
Token token = parser.nextToken();
if (token != XContentParser.Token.FIELD_NAME) {
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
}
String metricName = parser.currentName();
// TODO maybe switch to using a plugable registry later?
// TODO switch to using a pluggable registry
switch (metricName) {
case PrecisionAtK.NAME:
rc = PrecisionAtK.fromXContent(parser);
@ -101,13 +104,19 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
return ratedSearchHits;
}
/**
 * Collects a {@link DocumentKey} (index and id of the underlying search hit)
 * for every {@link RatedSearchHit} that carries no rating.
 *
 * @param ratedHits the search hits, each optionally joined with a rating
 * @return keys of the hits without a rating, in the order they were encountered
 */
static List<DocumentKey> filterUnknownDocuments(List<RatedSearchHit> ratedHits) {
    List<DocumentKey> unratedKeys = new ArrayList<>();
    for (RatedSearchHit ratedHit : ratedHits) {
        if (ratedHit.getRating().isPresent()) {
            // rated hits are not "unknown", skip them
            continue;
        }
        unratedKeys.add(new DocumentKey(ratedHit.getSearchHit().getIndex(), ratedHit.getSearchHit().getId()));
    }
    return unratedKeys;
}
/**
 * Combines the per-query evaluation results into one overall score.
 * The default is the arithmetic mean of the individual quality levels.
 * Note that an empty collection divides a zero sum by zero and so yields NaN.
 *
 * @param partialResults the evaluation results of the individual queries
 * @return the sum of all quality levels divided by the number of results
 */
default double combine(Collection<EvalQueryQuality> partialResults) {
    double qualitySum = partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum();
    int resultCount = partialResults.size();
    return qualitySum / resultCount;
}

View File

@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
@ -34,49 +34,45 @@ import java.util.Optional;
import javax.naming.directory.SearchResult;
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
/**
* Evaluate mean reciprocal rank. By default documents with a rating equal or bigger
* than 1 are considered to be "relevant" for the reciprocal rank calculation.
* This value can be changes using the "relevant_rating_threshold" parameter.
* Evaluates using mean reciprocal rank. By default documents with a rating
* equal or bigger than 1 are considered to be "relevant" for the reciprocal
* rank calculation. This value can be changed using the
* "relevant_rating_threshold" parameter.
*/
public class MeanReciprocalRank implements RankedListQualityMetric {
public class MeanReciprocalRank implements EvaluationMetric {
private static final int DEFAULT_RATING_THRESHOLD = 1;
public static final String NAME = "mean_reciprocal_rank";
/** ratings equal or above this value will be considered relevant. */
private int relevantRatingThreshhold = 1;
private final int relevantRatingThreshhold;
/**
* Creates a metric that uses the default relevant rating threshold
* (DEFAULT_RATING_THRESHOLD, which is 1).
*/
public MeanReciprocalRank() {
// delegate to the threshold constructor with the default value
this(DEFAULT_RATING_THRESHOLD);
}
public MeanReciprocalRank(StreamInput in) throws IOException {
this.relevantRatingThreshhold = in.readVInt();
}
@Override
public String getWriteableName() {
return NAME;
}
/**
* Sets the rating threshold above which ratings are considered to be
* "relevant" for this metric.
*/
public void setRelevantRatingThreshhold(int threshold) {
public MeanReciprocalRank(int threshold) {
if (threshold < 0) {
throw new IllegalArgumentException(
"Relevant rating threshold for precision must be positive integer.");
}
this.relevantRatingThreshhold = threshold;
}
@Override
public String getWriteableName() {
return NAME;
}
/**
* Return the rating threshold above which ratings are considered to be
* "relevant" for this metric. Defaults to 1.
@ -119,13 +115,19 @@ public class MeanReciprocalRank implements RankedListQualityMetric {
out.writeVInt(relevantRatingThreshhold);
}
private static final ParseField RELEVANT_RATING_FIELD = new ParseField(
"relevant_rating_threshold");
private static final ObjectParser<MeanReciprocalRank, Void> PARSER = new ObjectParser<>(
"reciprocal_rank", () -> new MeanReciprocalRank());
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
args -> {
Integer optionalThreshold = (Integer) args[0];
if (optionalThreshold == null) {
return new MeanReciprocalRank();
} else {
return new MeanReciprocalRank(optionalThreshold);
}
});
static {
PARSER.declareInt(MeanReciprocalRank::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
}
public static MeanReciprocalRank fromXContent(XContentParser parser) {

View File

@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchHit;
@ -34,7 +34,8 @@ import java.util.Optional;
import javax.naming.directory.SearchResult;
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
/**
* Evaluate Precision of the search results. Documents without a rating are
@ -42,15 +43,12 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW
* considered to be "relevant" for the precision calculation. This value can be
* changed using the "relevant_rating_threshold" parameter.
*/
public class PrecisionAtK implements RankedListQualityMetric {
public class PrecisionAtK implements EvaluationMetric {
public static final String NAME = "precision";
private static final ParseField RELEVANT_RATING_FIELD = new ParseField(
"relevant_rating_threshold");
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
private static final ObjectParser<PrecisionAtK, Void> PARSER = new ObjectParser<>(NAME,
PrecisionAtK::new);
/**
* This setting controls how unlabeled documents in the search hits are
@ -58,29 +56,47 @@ public class PrecisionAtK implements RankedListQualityMetric {
* as true or false positives. Set to 'false', they are treated as false
* positives.
*/
private boolean ignoreUnlabeled = false;
private final boolean ignoreUnlabeled;
/** ratings equal or above this value will be considered relevant. */
private int relevantRatingThreshhold = 1;
private final int relevantRatingThreshhold;
/**
 * Creates a precision metric.
 *
 * @param threshold ratings equal to or above this value are considered relevant; must not be negative
 * @param ignoreUnlabeled value of the 'ignore_unlabeled' setting
 * @throws IllegalArgumentException if {@code threshold} is negative
 */
public PrecisionAtK(int threshold, boolean ignoreUnlabeled) {
    if (threshold >= 0) {
        this.relevantRatingThreshhold = threshold;
        this.ignoreUnlabeled = ignoreUnlabeled;
    } else {
        throw new IllegalArgumentException(
                "Relevant rating threshold for precision must be positive integer.");
    }
}
/**
* Creates a precision metric with default settings: a relevant rating
* threshold of 1 and 'ignore_unlabeled' set to false.
*/
public PrecisionAtK() {
// delegate with the defaults: threshold 1, ignoreUnlabeled false
this(1, false);
}
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
args -> {
Integer threshHold = (Integer) args[0];
Boolean ignoreUnlabeled = (Boolean) args[1];
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
ignoreUnlabeled == null ? false : ignoreUnlabeled);
});
static {
PARSER.declareInt(PrecisionAtK::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
PARSER.declareBoolean(PrecisionAtK::setIgnoreUnlabeled, IGNORE_UNLABELED_FIELD);
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
}
public PrecisionAtK(StreamInput in) throws IOException {
relevantRatingThreshhold = in.readOptionalVInt();
ignoreUnlabeled = in.readOptionalBoolean();
PrecisionAtK(StreamInput in) throws IOException {
relevantRatingThreshhold = in.readVInt();
ignoreUnlabeled = in.readBoolean();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalVInt(relevantRatingThreshhold);
out.writeOptionalBoolean(ignoreUnlabeled);
out.writeVInt(relevantRatingThreshhold);
out.writeBoolean(ignoreUnlabeled);
}
@Override
@ -88,18 +104,6 @@ public class PrecisionAtK implements RankedListQualityMetric {
return NAME;
}
/**
* Sets the rating threshold above which ratings are considered to be
* "relevant" for this metric.
*/
public void setRelevantRatingThreshhold(int threshold) {
if (threshold < 0) {
throw new IllegalArgumentException(
"Relevant rating threshold for precision must be positive integer.");
}
this.relevantRatingThreshhold = threshold;
}
/**
* Return the rating threshold above which ratings are considered to be
* "relevant" for this metric. Defaults to 1.
@ -108,13 +112,6 @@ public class PrecisionAtK implements RankedListQualityMetric {
return relevantRatingThreshhold;
}
/**
* Sets the 'ìgnore_unlabeled' parameter
*/
public void setIgnoreUnlabeled(boolean ignoreUnlabeled) {
this.ignoreUnlabeled = ignoreUnlabeled;
}
/**
* Gets the 'ignore_unlabeled' parameter
*/

View File

@ -56,24 +56,19 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
}
/**
* Returns parsers for {@link NamedWriteable} this plugin will use over the
* transport protocol.
*
* Returns parsers for {@link NamedWriteable} objects that this plugin sends over the transport protocol.
* @see NamedWriteableRegistry
*/
@Override
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class,
PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class,
MeanReciprocalRank.NAME, MeanReciprocalRank::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class,
DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME,
PrecisionAtK.Breakdown::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class,
MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME, PrecisionAtK.Breakdown::new));
namedWriteables
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
return namedWriteables;
}
}

View File

@ -34,7 +34,9 @@ import org.elasticsearch.script.Script;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
@ -53,9 +55,9 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
* Collection of query specifications, that is e.g. search request templates
* to use for query translation.
*/
private Collection<RatedRequest> ratedRequests = new ArrayList<>();
private final List<RatedRequest> ratedRequests;
/** Definition of the quality metric, e.g. precision at N */
private RankedListQualityMetric metric;
private final EvaluationMetric metric;
/** Maximum number of requests to execute in parallel. */
private int maxConcurrentSearches = MAX_CONCURRENT_SEARCHES;
/** Default max number of requests. */
@ -63,7 +65,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
/** optional: Templates to base test requests on */
private Map<String, Script> templates = new HashMap<>();
public RankEvalSpec(Collection<RatedRequest> ratedRequests, RankedListQualityMetric metric,
public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric,
Collection<ScriptWithId> templates) {
if (ratedRequests == null || ratedRequests.size() < 1) {
throw new IllegalStateException(
@ -92,7 +94,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
}
}
public RankEvalSpec(Collection<RatedRequest> ratedRequests, RankedListQualityMetric metric) {
public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric) {
this(ratedRequests, metric, null);
}
@ -102,7 +104,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
for (int i = 0; i < specSize; i++) {
ratedRequests.add(new RatedRequest(in));
}
metric = in.readNamedWriteable(RankedListQualityMetric.class);
metric = in.readNamedWriteable(EvaluationMetric.class);
int size = in.readVInt();
for (int i = 0; i < size; i++) {
String key = in.readString();
@ -128,13 +130,13 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
}
/** Returns the metric to use for quality evaluation.*/
public RankedListQualityMetric getMetric() {
public EvaluationMetric getMetric() {
return metric;
}
/** Returns a list of intent to query translation specifications to evaluate. */
public Collection<RatedRequest> getRatedRequests() {
return ratedRequests;
public List<RatedRequest> getRatedRequests() {
return Collections.unmodifiableList(ratedRequests);
}
/** Returns the template to base test requests on. */
@ -160,8 +162,8 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<RankEvalSpec, Void> PARSER =
new ConstructingObjectParser<>("rank_eval",
a -> new RankEvalSpec((Collection<RatedRequest>) a[0],
(RankedListQualityMetric) a[1], (Collection<ScriptWithId>) a[2]));
a -> new RankEvalSpec((List<RatedRequest>) a[0],
(EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
static {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
@ -169,7 +171,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
} , REQUESTS_FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
try {
return RankedListQualityMetric.fromXContent(p);
return EvaluationMetric.fromXContent(p);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
}

View File

@ -33,13 +33,24 @@ import java.io.IOException;
import java.util.Objects;
/**
* A document ID and its rating for the query QA use case.
* Represents a document (specified by its _index/_id) and its corresponding rating
* with respect to a specific search query.
* <p>
* Json structure in a request:
* <pre>
* {
* "_index": "my_index",
* "_id": "doc1",
* "rating": 0
* }
* </pre>
*
*/
public class RatedDocument implements Writeable, ToXContentObject {
public static final ParseField RATING_FIELD = new ParseField("rating");
public static final ParseField DOC_ID_FIELD = new ParseField("_id");
public static final ParseField INDEX_FIELD = new ParseField("_index");
static final ParseField RATING_FIELD = new ParseField("rating");
static final ParseField DOC_ID_FIELD = new ParseField("_id");
static final ParseField INDEX_FIELD = new ParseField("_index");
private static final ConstructingObjectParser<RatedDocument, Void> PARSER = new ConstructingObjectParser<>("rated_document",
a -> new RatedDocument((String) a[0], (String) a[1], (Integer) a[2]));
@ -50,23 +61,19 @@ public class RatedDocument implements Writeable, ToXContentObject {
PARSER.declareInt(ConstructingObjectParser.constructorArg(), RATING_FIELD);
}
private int rating;
private DocumentKey key;
private final int rating;
private final DocumentKey key;
public RatedDocument(String index, String docId, int rating) {
this(new DocumentKey(index, docId), rating);
}
public RatedDocument(StreamInput in) throws IOException {
this.key = new DocumentKey(in);
this.rating = in.readVInt();
}
public RatedDocument(DocumentKey ratedDocumentKey, int rating) {
this.key = ratedDocumentKey;
public RatedDocument(String index, String id, int rating) {
this.key = new DocumentKey(index, id);
this.rating = rating;
}
RatedDocument(StreamInput in) throws IOException {
this.key = new DocumentKey(in.readString(), in.readString());
this.rating = in.readVInt();
}
public DocumentKey getKey() {
return this.key;
}
@ -76,7 +83,7 @@ public class RatedDocument implements Writeable, ToXContentObject {
}
public String getDocID() {
return key.getDocID();
return key.getDocId();
}
public int getRating() {
@ -85,11 +92,12 @@ public class RatedDocument implements Writeable, ToXContentObject {
@Override
public void writeTo(StreamOutput out) throws IOException {
this.key.writeTo(out);
out.writeString(key.getIndex());
out.writeString(key.getDocId());
out.writeVInt(rating);
}
public static RatedDocument fromXContent(XContentParser parser) {
static RatedDocument fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@ -97,7 +105,7 @@ public class RatedDocument implements Writeable, ToXContentObject {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INDEX_FIELD.getPreferredName(), key.getIndex());
builder.field(DOC_ID_FIELD.getPreferredName(), key.getDocID());
builder.field(DOC_ID_FIELD.getPreferredName(), key.getDocId());
builder.field(RATING_FIELD.getPreferredName(), rating);
builder.endObject();
return builder;
@ -124,4 +132,55 @@ public class RatedDocument implements Writeable, ToXContentObject {
public final int hashCode() {
return Objects.hash(key, rating);
}
/**
 * Composite key for a rated document, made up of the index name and the
 * document id. Both parts are required (neither null nor empty) and are
 * fixed once the key has been created.
 */
static class DocumentKey {

    private final String docId;
    private final String index;

    /**
     * @param index the index the document lives in
     * @param docId the id of the document
     * @throws IllegalArgumentException if either argument is null or empty
     */
    DocumentKey(String index, String docId) {
        // the index is checked before the doc id
        final boolean indexMissing = Strings.isNullOrEmpty(index);
        if (indexMissing) {
            throw new IllegalArgumentException("Index must be set for each rated document");
        }
        final boolean docIdMissing = Strings.isNullOrEmpty(docId);
        if (docIdMissing) {
            throw new IllegalArgumentException("DocId must be set for each rated document");
        }
        this.docId = docId;
        this.index = index;
    }

    String getIndex() {
        return index;
    }

    String getDocId() {
        return docId;
    }

    @Override
    public final boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj != null && obj.getClass() == getClass()) {
            DocumentKey that = (DocumentKey) obj;
            return Objects.equals(index, that.index) && Objects.equals(docId, that.docId);
        }
        return false;
    }

    @Override
    public final int hashCode() {
        return Objects.hash(index, docId);
    }

    /**
     * Renders the key in a JSON-like form; the values are inserted as-is.
     */
    @Override
    public String toString() {
        StringBuilder rendered = new StringBuilder();
        rendered.append("{\"_index\":\"").append(index);
        rendered.append("\",\"_id\":\"").append(docId);
        rendered.append("\"}");
        return rendered.toString();
    }
}
}

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException;

View File

@ -157,11 +157,11 @@ public class TransportRankEvalAction
private RatedRequest specification;
private Map<String, EvalQueryQuality> requestDetails;
private Map<String, Exception> errors;
private RankedListQualityMetric metric;
private EvaluationMetric metric;
private AtomicInteger responseCounter;
public RankEvalActionListener(ActionListener<RankEvalResponse> listener,
RankedListQualityMetric metric, RatedRequest specification,
EvaluationMetric metric, RatedRequest specification,
Map<String, EvalQueryQuality> details, Map<String, Exception> errors,
AtomicInteger responseCounter) {
this.listener = listener;

View File

@ -19,6 +19,7 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -36,20 +37,23 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments;
import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class DiscountedCumulativeGainTests extends ESTestCase {
/**
* Assuming the docs are ranked in the following order:
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
* log_2(rank + 1)
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5 4 | 0 | 0.0 | 2.321928094887362 | 0.0 5 | 1 | 1.0
* | 2.584962500721156 | 0.38685280723454163 6 | 2 | 3.0 | 2.807354922057604
* | 1.0686215613240666
* 1 | 3 | 7.0 | 1.0 | 7.0
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
* 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666
*
* dcg = 13.84826362927298 (sum of last column)
*/
@ -69,17 +73,18 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* Check with normalization: to get the maximal possible dcg, sort documents by
* relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
* log_2(rank + 1)
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* ---------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5 4 | 2 | 3.0 | 2.321928094887362 
* | 1.2920296742201793 5 | 1 | 1.0 | 2.584962500721156  | 0.38685280723454163 6
* | 0 | 0.0 | 2.807354922057604  | 0.0
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5
* 4 | 2 | 3.0 | 2.321928094887362 | 1.2920296742201793
* 5 | 1 | 1.0 | 2.584962500721156  | 0.38685280723454163
* 6 | 0 | 0.0 | 2.807354922057604  | 0.0
*
* idcg = 14.595390756454922 (sum of last column)
*/
dcg.setNormalize(true);
dcg = new DiscountedCumulativeGain(true, null);
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
}
@ -87,12 +92,14 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* This tests metric when some documents in the search result don't have a
* rating provided by the user.
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
* log_2(rank + 1)
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5 4 | n/a | n/a | n/a | n/a 5 | 1 | 1.0
* | 2.584962500721156 | 0.38685280723454163 6 | n/a | n/a | n/a | n/a
* 1 | 3 | 7.0 | 1.0 | 7.0
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | n/a | n/a | n/a | n/a
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
* 6 | n/a | n/a | n/a | n/a
*
* dcg = 12.779642067948913 (sum of last column)
*/
@ -118,16 +125,18 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* Check with normalization: to get the maximal possible dcg, sort documents by
* relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
* log_2(rank + 1)
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* ----------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339
* 5 | n.a | n.a | n.a.  | n.a. 6 | n.a | n.a | n.a  | n.a
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5
* 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339
* 5 | n.a | n.a | n.a.  | n.a.
* 6 | n.a | n.a | n.a  | n.a
*
* idcg = 13.347184833073591 (sum of last column)
*/
dcg.setNormalize(true);
dcg = new DiscountedCumulativeGain(true, null);
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
}
@ -136,13 +145,15 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* documents than search hits because we restrict DCG to be calculated at the
* fourth position
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
* log_2(rank + 1)
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* -------------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5 4 | n/a | n/a | n/a | n/a
* ----------------------------------------------------------------- 5 | 1 | 1.0
* | 2.584962500721156 | 0.38685280723454163 6 | n/a | n/a | n/a | n/a
* 1 | 3 | 7.0 | 1.0 | 7.0
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
* 3 | 3 | 7.0 | 2.0 | 3.5
* 4 | n/a | n/a | n/a | n/a
* -----------------------------------------------------------------
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
* 6 | n/a | n/a | n/a | n/a
*
* dcg = 12.392789260714371 (sum of last column until position 4)
*/
@ -171,22 +182,24 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
* Check with normalization: to get the maximal possible dcg, sort documents by
* relevance in descending order
*
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
* log_2(rank + 1)
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
* ---------------------------------------------------------------------------------------
* 1 | 3 | 7.0 | 1.0  | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339
* 1 | 3 | 7.0 | 1.0  | 7.0
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
* 3 | 2 | 3.0 | 2.0  | 1.5
* 4 | 1 | 1.0 | 2.321928094887362   | 0.43067655807339
* ---------------------------------------------------------------------------------------
* 5 | n.a | n.a | n.a.  | n.a. 6 | n.a | n.a | n.a  | n.a
* 5 | n.a | n.a | n.a.  | n.a.
* 6 | n.a | n.a | n.a  | n.a
*
* idcg = 13.347184833073591 (sum of last column)
*/
dcg.setNormalize(true);
dcg = new DiscountedCumulativeGain(true, null);
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
}
public void testParseFromXContent() throws IOException {
String xContent = " {\n" + " \"unknown_doc_rating\": 2,\n" + " \"normalize\": true\n" + "}";
String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
assertEquals(2, dcgAt.getUnknownDocRating().intValue());
@ -217,29 +230,25 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
public void testSerialization() throws IOException {
DiscountedCumulativeGain original = createTestItem();
DiscountedCumulativeGain deserialized = RankEvalTestHelper.copy(original, DiscountedCumulativeGain::new);
DiscountedCumulativeGain deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
DiscountedCumulativeGain::new);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
DiscountedCumulativeGain testItem = createTestItem();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
RankEvalTestHelper.copy(testItem, DiscountedCumulativeGain::new));
checkEqualsAndHashCode(createTestItem(), original -> {
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating());
}, DiscountedCumulativeGainTests::mutateTestItem);
}
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
boolean normalise = original.getNormalize();
int unknownDocRating = original.getUnknownDocRating();
DiscountedCumulativeGain gain = new DiscountedCumulativeGain();
gain.setNormalize(normalise);
gain.setUnknownDocRating(unknownDocRating);
List<Runnable> mutators = new ArrayList<>();
mutators.add(() -> gain.setNormalize(!original.getNormalize()));
mutators.add(() -> gain.setUnknownDocRating(randomValueOtherThan(unknownDocRating, () -> randomIntBetween(0, 10))));
randomFrom(mutators).run();
return gain;
if (randomBoolean()) {
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating());
} else {
return new DiscountedCumulativeGain(original.getNormalize(),
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)));
}
}
}

View File

@ -1,67 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
/**
 * Tests for {@link DocumentKey}: checks equals/hashCode against a mutated and
 * an equal instance, and checks that a copied key is equal to, but not the same
 * object as, the original.
 */
public class DocumentKeyTests extends ESTestCase {

    /**
     * Builds a key whose index and document id are random alphabetic strings
     * of length 1 to 10 (index drawn first, then the document id).
     */
    static DocumentKey createRandomRatedDocumentKey() {
        final String randomIndex = randomAlphaOfLengthBetween(1, 10);
        final String randomDocId = randomAlphaOfLengthBetween(1, 10);
        return new DocumentKey(randomIndex, randomDocId);
    }

    /** Instance-level entry point that delegates to {@link #createRandomRatedDocumentKey()}. */
    public DocumentKey createTestItem() {
        return createRandomRatedDocumentKey();
    }

    /**
     * Returns a key that differs from {@code original} in exactly one
     * component: either the index or the document id gets a trailing
     * underscore, chosen at random.
     */
    public DocumentKey mutateTestItem(DocumentKey original) {
        final String originalIndex = original.getIndex();
        final String originalDocId = original.getDocID();
        final int componentToChange = randomIntBetween(0, 1);
        if (componentToChange == 0) {
            return new DocumentKey(originalIndex + "_", originalDocId);
        }
        if (componentToChange == 1) {
            return new DocumentKey(originalIndex, originalDocId + "_");
        }
        // only two components exist, so any other value is a bug in the test itself
        throw new IllegalStateException("The test should only allow two parameters mutated");
    }

    /** Hands a key, a mutated variant and a field-by-field copy to the shared equals/hashCode check. */
    public void testEqualsAndHash() throws IOException {
        final DocumentKey key = createRandomRatedDocumentKey();
        final DocumentKey mutatedKey = mutateTestItem(key);
        final DocumentKey equalKey = new DocumentKey(key.getIndex(), key.getDocID());
        RankEvalTestHelper.testHashCodeAndEquals(key, mutatedKey, equalKey);
    }

    /** Copies a key through the test helper and compares the copy with the original. */
    public void testSerialization() throws IOException {
        final DocumentKey key = createTestItem();
        final DocumentKey roundTripped = RankEvalTestHelper.copy(key, DocumentKey::new);
        assertEquals(roundTripped, key);
        assertEquals(roundTripped.hashCode(), key.hashCode());
        assertNotSame(roundTripped, key);
    }
}

View File

@ -20,22 +20,24 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class EvalQueryQualityTests extends ESTestCase {
private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(
new RankEvalPlugin().getNamedWriteables());
private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(new RankEvalPlugin().getNamedWriteables());
public static EvalQueryQuality randomEvalQueryQuality() {
List<DocumentKey> unknownDocs = new ArrayList<>();
int numberOfUnknownDocs = randomInt(5);
for (int i = 0; i < numberOfUnknownDocs; i++) {
unknownDocs.add(DocumentKeyTests.createRandomRatedDocumentKey());
unknownDocs.add(new DocumentKey(randomAlphaOfLength(10), randomAlphaOfLength(10)));
}
int numberOfSearchHits = randomInt(5);
List<RatedSearchHit> ratedHits = new ArrayList<>();
@ -54,17 +56,18 @@ public class EvalQueryQualityTests extends ESTestCase {
public void testSerialization() throws IOException {
EvalQueryQuality original = randomEvalQueryQuality();
EvalQueryQuality deserialized = RankEvalTestHelper.copy(original, EvalQueryQuality::new,
namedWritableRegistry);
EvalQueryQuality deserialized = copy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
/**
 * Produces a copy of {@code original} via {@code ESTestCase.copyWriteable},
 * using this test's named-writeable registry and the {@link EvalQueryQuality}
 * stream constructor.
 */
private static EvalQueryQuality copy(EvalQueryQuality original) throws IOException {
    final EvalQueryQuality duplicate = ESTestCase.copyWriteable(original, namedWritableRegistry, EvalQueryQuality::new);
    return duplicate;
}
public void testEqualsAndHash() throws IOException {
EvalQueryQuality testItem = randomEvalQueryQuality();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
RankEvalTestHelper.copy(testItem, EvalQueryQuality::new, namedWritableRegistry));
checkEqualsAndHashCode(randomEvalQueryQuality(), EvalQueryQualityTests::copy, EvalQueryQualityTests::mutateTestItem);
}
private static EvalQueryQuality mutateTestItem(EvalQueryQuality original) {

View File

@ -19,14 +19,15 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.test.ESTestCase;
@ -38,21 +39,35 @@ import java.util.Collections;
import java.util.List;
import java.util.Vector;
public class ReciprocalRankTests extends ESTestCase {
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class MeanReciprocalRankTests extends ESTestCase {
/**
 * Parses two JSON bodies: an empty object, for which the test expects a
 * relevant rating threshold of 1, and one that sets
 * {@code relevant_rating_threshold} to 2 explicitly.
 */
public void testParseFromXContent() throws IOException {
    final String emptyBody = "{ }";
    try (XContentParser emptyBodyParser = createParser(JsonXContent.jsonXContent, emptyBody)) {
        final MeanReciprocalRank parsedFromEmpty = MeanReciprocalRank.fromXContent(emptyBodyParser);
        assertEquals(1, parsedFromEmpty.getRelevantRatingThreshold());
    }

    final String explicitThresholdBody = "{ \"relevant_rating_threshold\": 2 }";
    try (XContentParser explicitParser = createParser(JsonXContent.jsonXContent, explicitThresholdBody)) {
        final MeanReciprocalRank parsedWithThreshold = MeanReciprocalRank.fromXContent(explicitParser);
        assertEquals(2, parsedWithThreshold.getRelevantRatingThreshold());
    }
}
public void testMaxAcceptableRank() {
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank();
int searchHits = randomIntBetween(1, 50);
SearchHit[] hits = createSearchHits(0, searchHits, "test");
List<RatedDocument> ratedDocs = new ArrayList<>();
int relevantAt = randomIntBetween(0, searchHits);
for (int i = 0; i <= searchHits; i++) {
if (i == relevantAt) {
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.RELEVANT.ordinal()));
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.RELEVANT.ordinal()));
} else {
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.IRRELEVANT.ordinal()));
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.IRRELEVANT.ordinal()));
}
}
@ -76,9 +91,9 @@ public class ReciprocalRankTests extends ESTestCase {
int relevantAt = randomIntBetween(0, 9);
for (int i = 0; i <= 20; i++) {
if (i == relevantAt) {
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.RELEVANT.ordinal()));
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.RELEVANT.ordinal()));
} else {
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.IRRELEVANT.ordinal()));
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.IRRELEVANT.ordinal()));
}
}
@ -101,8 +116,7 @@ public class ReciprocalRankTests extends ESTestCase {
rated.add(new RatedDocument("test", "4", 4));
SearchHit[] hits = createSearchHits(0, 5, "test");
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank();
reciprocalRank.setRelevantRatingThreshhold(2);
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2);
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
@ -153,35 +167,31 @@ public class ReciprocalRankTests extends ESTestCase {
}
private static MeanReciprocalRank createTestItem() {
MeanReciprocalRank testItem = new MeanReciprocalRank();
testItem.setRelevantRatingThreshhold(randomIntBetween(0, 20));
return testItem;
return new MeanReciprocalRank(randomIntBetween(0, 20));
}
public void testSerialization() throws IOException {
MeanReciprocalRank original = createTestItem();
MeanReciprocalRank deserialized = RankEvalTestHelper.copy(original, MeanReciprocalRank::new);
MeanReciprocalRank deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
MeanReciprocalRank::new);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
MeanReciprocalRank testItem = createTestItem();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
RankEvalTestHelper.copy(testItem, MeanReciprocalRank::new));
checkEqualsAndHashCode(createTestItem(), MeanReciprocalRankTests::copy, MeanReciprocalRankTests::mutate);
}
private static MeanReciprocalRank mutateTestItem(MeanReciprocalRank testItem) {
int relevantThreshold = testItem.getRelevantRatingThreshold();
MeanReciprocalRank rank = new MeanReciprocalRank();
rank.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10)));
return rank;
/** Builds a new metric carrying the same relevant rating threshold as {@code testItem}. */
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
    final int sameThreshold = testItem.getRelevantRatingThreshold();
    return new MeanReciprocalRank(sameThreshold);
}
/**
 * Builds a metric whose relevant rating threshold is a random value in
 * [0, 10] that differs from the threshold of {@code testItem}.
 */
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
    final int currentThreshold = testItem.getRelevantRatingThreshold();
    final int differentThreshold = randomValueOtherThan(currentThreshold, () -> randomIntBetween(0, 10));
    return new MeanReciprocalRank(differentThreshold);
}
public void testInvalidRelevantThreshold() {
MeanReciprocalRank prez = new MeanReciprocalRank();
expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1));
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1));
}
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -38,11 +39,13 @@ import java.util.Collections;
import java.util.List;
import java.util.Vector;
public class PrecisionTests extends ESTestCase {
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class PrecisionAtKTests extends ESTestCase {
public void testPrecisionAtFiveCalculation() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals(1, evaluated.getQualityLevel(), 0.00001);
assertEquals(1, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -51,11 +54,11 @@ public class PrecisionTests extends ESTestCase {
public void testPrecisionAtFiveIgnoreOneResult() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "2", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "3", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "4", Rating.IRRELEVANT.ordinal()));
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "2", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "3", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "4", TestRatingEnum.IRRELEVANT.ordinal()));
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 4 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(4, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -69,13 +72,12 @@ public class PrecisionTests extends ESTestCase {
*/
public void testPrecisionAtFiveRelevanceThreshold() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(new RatedDocument("test", "0", 0));
rated.add(new RatedDocument("test", "1", 1));
rated.add(new RatedDocument("test", "2", 2));
rated.add(new RatedDocument("test", "3", 3));
rated.add(new RatedDocument("test", "4", 4));
PrecisionAtK precisionAtN = new PrecisionAtK();
precisionAtN.setRelevantRatingThreshhold(2);
rated.add(createRatedDoc("test", "0", 0));
rated.add(createRatedDoc("test", "1", 1));
rated.add(createRatedDoc("test", "2", 2));
rated.add(createRatedDoc("test", "3", 3));
rated.add(createRatedDoc("test", "4", 4));
PrecisionAtK precisionAtN = new PrecisionAtK(2, false);
EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test"), rated);
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -84,11 +86,11 @@ public class PrecisionTests extends ESTestCase {
public void testPrecisionAtFiveCorrectIndex() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(new RatedDocument("test_other", "0", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test_other", "1", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "2", Rating.IRRELEVANT.ordinal()));
rated.add(createRatedDoc("test_other", "0", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test_other", "1", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "2", TestRatingEnum.IRRELEVANT.ordinal()));
// the following search hits contain only the last three documents
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated);
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
@ -98,8 +100,8 @@ public class PrecisionTests extends ESTestCase {
public void testIgnoreUnlabeled() {
List<RatedDocument> rated = new ArrayList<>();
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal()));
// add an unlabeled search hit
SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test"), 3);
searchHits[2] = new SearchHit(2, "2", new Text("testtype"), Collections.emptyMap());
@ -111,8 +113,7 @@ public class PrecisionTests extends ESTestCase {
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK();
prec.setIgnoreUnlabeled(true);
PrecisionAtK prec = new PrecisionAtK(1, true);
evaluated = prec.evaluate("id", searchHits, rated);
assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001);
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -131,8 +132,7 @@ public class PrecisionTests extends ESTestCase {
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
// also try with setting `ignore_unlabeled`
PrecisionAtK prec = new PrecisionAtK();
prec.setIgnoreUnlabeled(true);
PrecisionAtK prec = new PrecisionAtK(1, true);
evaluated = prec.evaluate("id", hits, Collections.emptyList());
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
@ -158,16 +158,11 @@ public class PrecisionTests extends ESTestCase {
public void testInvalidRelevantThreshold() {
PrecisionAtK prez = new PrecisionAtK();
expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1));
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false));
}
public static PrecisionAtK createTestItem() {
PrecisionAtK precision = new PrecisionAtK();
if (randomBoolean()) {
precision.setRelevantRatingThreshhold(randomIntBetween(0, 10));
}
precision.setIgnoreUnlabeled(randomBoolean());
return precision;
return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean());
}
public void testXContentRoundtrip() throws IOException {
@ -186,29 +181,28 @@ public class PrecisionTests extends ESTestCase {
public void testSerialization() throws IOException {
PrecisionAtK original = createTestItem();
PrecisionAtK deserialized = RankEvalTestHelper.copy(original, PrecisionAtK::new);
PrecisionAtK deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
PrecisionAtK::new);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
PrecisionAtK testItem = createTestItem();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), RankEvalTestHelper.copy(testItem, PrecisionAtK::new));
checkEqualsAndHashCode(createTestItem(), PrecisionAtKTests::copy, PrecisionAtKTests::mutate);
}
private static PrecisionAtK mutateTestItem(PrecisionAtK original) {
boolean ignoreUnlabeled = original.getIgnoreUnlabeled();
int relevantThreshold = original.getRelevantRatingThreshold();
PrecisionAtK precision = new PrecisionAtK();
precision.setIgnoreUnlabeled(ignoreUnlabeled);
precision.setRelevantRatingThreshhold(relevantThreshold);
/** Builds a new metric with the same threshold and ignore-unlabeled flag as {@code original}. */
private static PrecisionAtK copy(PrecisionAtK original) {
    final int threshold = original.getRelevantRatingThreshold();
    final boolean ignoreUnlabeled = original.getIgnoreUnlabeled();
    return new PrecisionAtK(threshold, ignoreUnlabeled);
}
List<Runnable> mutators = new ArrayList<>();
mutators.add(() -> precision.setIgnoreUnlabeled(!ignoreUnlabeled));
mutators.add(() -> precision.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10))));
randomFrom(mutators).run();
return precision;
/**
 * Returns a metric that differs from {@code original} in exactly one
 * property, picked at random: either the ignore-unlabeled flag is inverted,
 * or the relevant rating threshold is replaced by a different random value
 * in [0, 10].
 */
private static PrecisionAtK mutate(PrecisionAtK original) {
    final boolean flipIgnoreUnlabeled = randomBoolean();
    if (flipIgnoreUnlabeled) {
        return new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled());
    }
    final int differentThreshold = randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10));
    return new PrecisionAtK(differentThreshold, original.getIgnoreUnlabeled());
}
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index) {
@ -220,7 +214,7 @@ public class PrecisionTests extends ESTestCase {
return hits;
}
public enum Rating {
IRRELEVANT, RELEVANT;
/** Shorthand for constructing a {@link RatedDocument} from an index name, document id and rating. */
private static RatedDocument createRatedDoc(String index, String id, int rating) {
    final RatedDocument ratedDoc = new RatedDocument(index, id, rating);
    return ratedDoc;
}
}

View File

@ -22,7 +22,6 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
@ -35,7 +34,7 @@ import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments;
import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments;
public class RankEvalRequestIT extends ESIntegTestCase {
@Override
@ -82,8 +81,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
specifications.add(berlinRequest);
PrecisionAtK metric = new PrecisionAtK();
metric.setIgnoreUnlabeled(true);
PrecisionAtK metric = new PrecisionAtK(1, true);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
@ -106,7 +104,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
if (id.equals("1") || id.equals("6")) {
assertFalse(hit.getRating().isPresent());
} else {
assertEquals(Rating.RELEVANT.ordinal(), hit.getRating().get().intValue());
assertEquals(TestRatingEnum.RELEVANT.ordinal(), hit.getRating().get().intValue());
}
}
}
@ -117,7 +115,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
for (RatedSearchHit hit : hitsAndRatings) {
String id = hit.getSearchHit().getId();
if (id.equals("1")) {
assertEquals(Rating.RELEVANT.ordinal(), hit.getRating().get().intValue());
assertEquals(TestRatingEnum.RELEVANT.ordinal(), hit.getRating().get().intValue());
} else {
assertFalse(hit.getRating().isPresent());
}
@ -167,7 +165,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
private static List<RatedDocument> createRelevant(String... docs) {
List<RatedDocument> relevant = new ArrayList<>();
for (String doc : docs) {
relevant.add(new RatedDocument("test", doc, Rating.RELEVANT.ordinal()));
relevant.add(new RatedDocument("test", doc, TestRatingEnum.RELEVANT.ordinal()));
}
return relevant;
}

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
@ -43,7 +44,7 @@ public class RankEvalResponseTests extends ESTestCase {
int numberOfUnknownDocs = randomIntBetween(0, 5);
List<DocumentKey> unknownDocs = new ArrayList<>(numberOfUnknownDocs);
for (int d = 0; d < numberOfUnknownDocs; d++) {
unknownDocs.add(DocumentKeyTests.createRandomRatedDocumentKey());
unknownDocs.add(new DocumentKey(randomAlphaOfLength(10), randomAlphaOfLength(10)));
}
EvalQueryQuality evalQuality = new EvalQueryQuality(id,
randomDoubleBetween(0.0, 1.0, true));
@ -65,12 +66,9 @@ public class RankEvalResponseTests extends ESTestCase {
try (StreamInput in = output.bytes().streamInput()) {
RankEvalResponse deserializedResponse = new RankEvalResponse();
deserializedResponse.readFrom(in);
assertEquals(randomResponse.getQualityLevel(),
deserializedResponse.getQualityLevel(), Double.MIN_VALUE);
assertEquals(randomResponse.getPartialResults(),
deserializedResponse.getPartialResults());
assertEquals(randomResponse.getFailures().keySet(),
deserializedResponse.getFailures().keySet());
assertEquals(randomResponse.getQualityLevel(), deserializedResponse.getQualityLevel(), Double.MIN_VALUE);
assertEquals(randomResponse.getPartialResults(), deserializedResponse.getPartialResults());
assertEquals(randomResponse.getFailures().keySet(), deserializedResponse.getFailures().keySet());
assertNotSame(randomResponse, deserializedResponse);
assertEquals(-1, in.read());
}

View File

@ -45,6 +45,8 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Supplier;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class RankEvalSpecTests extends ESTestCase {
private static <T> List<T> randomList(Supplier<T> randomSupplier) {
@ -57,9 +59,9 @@ public class RankEvalSpecTests extends ESTestCase {
}
private static RankEvalSpec createTestItem() throws IOException {
RankedListQualityMetric metric;
EvaluationMetric metric;
if (randomBoolean()) {
metric = PrecisionTests.createTestItem();
metric = PrecisionAtKTests.createTestItem();
} else {
metric = DiscountedCumulativeGainTests.createTestItem();
}
@ -111,41 +113,30 @@ public class RankEvalSpecTests extends ESTestCase {
public void testSerialization() throws IOException {
RankEvalSpec original = createTestItem();
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME,
DiscountedCumulativeGain::new));
namedWriteables
.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
RankEvalSpec deserialized = RankEvalTestHelper.copy(original, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables));
RankEvalSpec deserialized = copy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
RankEvalSpec testItem = createTestItem();
private static RankEvalSpec copy(RankEvalSpec original) throws IOException {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME,
DiscountedCumulativeGain::new));
namedWriteables
.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
RankEvalSpec mutant = RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables));
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(mutant),
RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(namedWriteables), RankEvalSpec::new);
}
private static RankEvalSpec mutateTestItem(RankEvalSpec mutant) {
Collection<RatedRequest> ratedRequests = mutant.getRatedRequests();
RankedListQualityMetric metric = mutant.getMetric();
Map<String, Script> templates = mutant.getTemplates();
/**
 * Runs the shared equals/hashCode checks against a freshly created test item, passing
 * {@code copy} as the copy function and {@code mutateTestItem} as the mutation function.
 */
public void testEqualsAndHash() throws IOException {
    RankEvalSpec randomSpec = createTestItem();
    checkEqualsAndHashCode(randomSpec, RankEvalSpecTests::copy, RankEvalSpecTests::mutateTestItem);
}
private static RankEvalSpec mutateTestItem(RankEvalSpec original) {
List<RatedRequest> ratedRequests = new ArrayList<>(original.getRatedRequests());
EvaluationMetric metric = original.getMetric();
Map<String, Script> templates = original.getTemplates();
int mutate = randomIntBetween(0, 2);
switch (mutate) {
@ -177,7 +168,7 @@ public class RankEvalSpecTests extends ESTestCase {
}
public void testMissingRatedRequestsFailsParsing() {
RankedListQualityMetric metric = new PrecisionAtK();
EvaluationMetric metric = new PrecisionAtK();
expectThrows(IllegalStateException.class, () -> new RankEvalSpec(new ArrayList<>(), metric));
expectThrows(IllegalStateException.class, () -> new RankEvalSpec(null, metric));
}
@ -189,7 +180,7 @@ public class RankEvalSpecTests extends ESTestCase {
}
public void testMissingTemplateAndSearchRequestFailsParsing() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
Map<String, Object> params = new HashMap<>();
params.put("key", "value");

View File

@ -1,94 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
// TODO replace by infra from ESTestCase
/**
 * Static helpers for the rank-eval tests: a generic equals/hashCode contract check and
 * stream round-trip copy helpers for {@link Writeable} objects.
 */
public final class RankEvalTestHelper {

    private RankEvalTestHelper() {
        // static utility holder; all members are static, so instances are never needed
    }

    /**
     * Asserts basic properties of {@code equals} and {@code hashCode} on the given objects.
     *
     * @param testItem
     *            the object under test; must not equal {@code null}, must not equal an
     *            object of an incompatible type, and must equal itself with a hash code
     *            that is the same on repeated calls
     * @param mutation
     *            an object expected to be different from {@code testItem}
     * @param secondCopy
     *            a distinct instance expected to be equal to {@code testItem} (checked in
     *            both directions) and to have the same hash code
     */
    public static <T> void testHashCodeAndEquals(T testItem, T mutation, T secondCopy) {
        assertFalse("testItem is equal to null", testItem.equals(null));
        assertFalse("testItem is equal to incompatible type", testItem.equals(""));
        assertTrue("testItem is not equal to self", testItem.equals(testItem));
        assertThat("same testItem's hashcode returns different values if called multiple times",
                testItem.hashCode(), equalTo(testItem.hashCode()));

        assertThat("different testItem should not be equal", mutation, not(equalTo(testItem)));

        assertNotSame("testItem copy is not same as original", testItem, secondCopy);
        assertTrue("testItem is not equal to its copy", testItem.equals(secondCopy));
        assertTrue("equals is not symmetric", secondCopy.equals(testItem));
        assertThat("testItem copy's hashcode is different from original hashcode",
                secondCopy.hashCode(), equalTo(testItem.hashCode()));
    }

    /**
     * Make a deep copy of an object by running it through a BytesStreamOutput, using an
     * empty {@link NamedWriteableRegistry}.
     *
     * @param original
     *            the original object
     * @param reader
     *            a function able to create a new copy of this type
     * @return a new copy of the original object
     * @throws IOException
     *             if writing the original or reading the copy fails
     */
    public static <T extends Writeable> T copy(T original, Writeable.Reader<T> reader)
            throws IOException {
        return copy(original, reader, new NamedWriteableRegistry(Collections.emptyList()));
    }

    /**
     * Make a deep copy of an object by running it through a BytesStreamOutput
     *
     * @param original
     *            the original object
     * @param reader
     *            a function able to create a new copy of this type
     * @param namedWriteableRegistry
     *            must be non-empty if the object itself or nested object
     *            implement {@link NamedWriteable}
     * @return a new copy of the original object
     * @throws IOException
     *             if writing the original or reading the copy fails
     */
    public static <T extends Writeable> T copy(T original, Writeable.Reader<T> reader,
            NamedWriteableRegistry namedWriteableRegistry) throws IOException {
        try (BytesStreamOutput output = new BytesStreamOutput()) {
            original.writeTo(output);
            // wrap the raw bytes so that named writeables can be resolved through the registry
            try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(),
                    namedWriteableRegistry)) {
                return reader.read(in);
            }
        }
    }
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
@ -27,15 +28,14 @@ import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.Collections;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class RatedDocumentTests extends ESTestCase {
public static RatedDocument createRatedDocument() {
String index = randomAlphaOfLength(10);
String docId = randomAlphaOfLength(10);
int rating = randomInt();
return new RatedDocument(index, docId, rating);
return new RatedDocument(randomAlphaOfLength(10), randomAlphaOfLength(10), randomInt());
}
public void testXContentParsing() throws IOException {
@ -52,22 +52,17 @@ public class RatedDocumentTests extends ESTestCase {
public void testSerialization() throws IOException {
RatedDocument original = createRatedDocument();
RatedDocument deserialized = RankEvalTestHelper.copy(original, RatedDocument::new);
RatedDocument deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
RatedDocument::new);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
RatedDocument testItem = createRatedDocument();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), RankEvalTestHelper.copy(testItem, RatedDocument::new));
}
public void testInvalidParsing() {
expectThrows(IllegalArgumentException.class, () -> new RatedDocument(null, "abc", 10));
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("", "abc", 10));
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", "", 10));
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", null, 10));
checkEqualsAndHashCode(createRatedDocument(), original -> {
return new RatedDocument(original.getIndex(), original.getDocID(), original.getRating());
}, RatedDocumentTests::mutateTestItem);
}
private static RatedDocument mutateTestItem(RatedDocument original) {

View File

@ -47,6 +47,7 @@ import java.util.stream.Stream;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class RatedRequestsTests extends ESTestCase {
@ -140,32 +141,26 @@ public class RatedRequestsTests extends ESTestCase {
for (int i = 0; i < size; i++) {
indices.add(randomAlphaOfLengthBetween(0, 50));
}
RatedRequest original = createTestItem(indices, randomBoolean());
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
RatedRequest deserialized = RankEvalTestHelper.copy(original, RatedRequest::new, new NamedWriteableRegistry(namedWriteables));
RatedRequest deserialized = copy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
/**
 * Copies the given request via {@code ESTestCase.copyWriteable}, reading it back with
 * {@code RatedRequest::new}. The registry passed along contains a single entry that
 * registers {@code MatchAllQueryBuilder} under the {@code QueryBuilder} category.
 */
private static RatedRequest copy(RatedRequest original) throws IOException {
    List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
    NamedWriteableRegistry.Entry matchAllEntry = new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME,
            MatchAllQueryBuilder::new);
    entries.add(matchAllEntry);
    NamedWriteableRegistry registry = new NamedWriteableRegistry(entries);
    return ESTestCase.copyWriteable(original, registry, RatedRequest::new);
}
public void testEqualsAndHash() throws IOException {
List<String> indices = new ArrayList<>();
int size = randomIntBetween(0, 20);
for (int i = 0; i < size; i++) {
indices.add(randomAlphaOfLengthBetween(0, 50));
}
RatedRequest testItem = createTestItem(indices, randomBoolean());
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
RankEvalTestHelper.copy(testItem, RatedRequest::new, new NamedWriteableRegistry(namedWriteables)));
checkEqualsAndHashCode(createTestItem(indices, randomBoolean()), RatedRequestsTests::copy, RatedRequestsTests::mutateTestItem);
}
private static RatedRequest mutateTestItem(RatedRequest original) {
@ -220,8 +215,7 @@ public class RatedRequestsTests extends ESTestCase {
}
public void testDuplicateRatedDocThrowsException() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1),
new RatedDocument(new DocumentKey("index1", "id1"), 5));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1), new RatedDocument("index1", "id1", 5));
// search request set, no summary fields
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
@ -237,45 +231,45 @@ public class RatedRequestsTests extends ESTestCase {
}
public void testNullSummaryFieldsTreatment() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
RatedRequest request = new RatedRequest("id", ratedDocs, new SearchSourceBuilder());
expectThrows(IllegalArgumentException.class, () -> request.setSummaryFields(null));
}
public void testNullParamsTreatment() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
RatedRequest request = new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), null, null);
assertNotNull(request.getParams());
}
public void testSettingParamsAndRequestThrows() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
Map<String, Object> params = new HashMap<>();
params.put("key", "value");
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), params, null));
}
public void testSettingNeitherParamsNorRequestThrows() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, null));
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, new HashMap<>(), "templateId"));
}
public void testSettingParamsWithoutTemplateIdThrows() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
Map<String, Object> params = new HashMap<>();
params.put("key", "value");
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, params, null));
}
public void testSettingTemplateIdAndRequestThrows() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
expectThrows(IllegalArgumentException.class,
() -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), null, "templateId"));
}
public void testSettingTemplateIdNoParamsThrows() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, null, "templateId"));
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
@ -27,6 +28,8 @@ import java.io.IOException;
import java.util.Collections;
import java.util.Optional;
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
public class RatedSearchHitTests extends ESTestCase {
public static RatedSearchHit randomRatedSearchHit() {
@ -57,15 +60,18 @@ public class RatedSearchHitTests extends ESTestCase {
public void testSerialization() throws IOException {
RatedSearchHit original = randomRatedSearchHit();
RatedSearchHit deserialized = RankEvalTestHelper.copy(original, RatedSearchHit::new);
RatedSearchHit deserialized = copy(original);
assertEquals(deserialized, original);
assertEquals(deserialized.hashCode(), original.hashCode());
assertNotSame(deserialized, original);
}
public void testEqualsAndHash() throws IOException {
RatedSearchHit testItem = randomRatedSearchHit();
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
RankEvalTestHelper.copy(testItem, RatedSearchHit::new));
checkEqualsAndHashCode(randomRatedSearchHit(), RatedSearchHitTests::copy, RatedSearchHitTests::mutateTestItem);
}
/**
 * Copies the given hit via {@code ESTestCase.copyWriteable}, using an empty
 * {@link NamedWriteableRegistry} and {@code RatedSearchHit::new} as the reader.
 */
private static RatedSearchHit copy(RatedSearchHit original) throws IOException {
    NamedWriteableRegistry emptyRegistry = new NamedWriteableRegistry(Collections.emptyList());
    return ESTestCase.copyWriteable(original, emptyRegistry, RatedSearchHit::new);
}
}

View File

@ -0,0 +1,24 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
/**
 * Package-private enum with two constants, declared in the order {@code IRRELEVANT}
 * (ordinal 0) then {@code RELEVANT} (ordinal 1).
 * NOTE(review): presumably used by the rank-eval tests as a sample rating scale; confirm
 * against its callers, which are not visible here.
 */
enum TestRatingEnum {
    IRRELEVANT,
    RELEVANT
}

View File

@ -35,4 +35,4 @@
- match: { rank_eval.details.amsterdam_query.unknown_docs: [ ]}
- match: { rank_eval.details.amsterdam_query.metric_details: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}}
- is_true: rank_eval.failures.invalid_queryy
- is_true: rank_eval.failures.invalid_query