mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-23 05:15:04 +00:00
Improving and cleaning up tests
Removing the unnecessary RankEvalTestHelper, making use of the common test infra in ESTestCase, also hardening a few of the classes by making more fields final.
This commit is contained in:
parent
5c65a59369
commit
e278c1d17d
@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
@ -35,22 +35,25 @@ import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
||||
|
||||
public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
public class DiscountedCumulativeGain implements EvaluationMetric {
|
||||
|
||||
/** If set to true, the dcg will be normalized (ndcg) */
|
||||
private boolean normalize;
|
||||
private final boolean normalize;
|
||||
|
||||
/**
|
||||
* If set, this will be the rating for docs the user hasn't supplied an
|
||||
* explicit rating for
|
||||
*/
|
||||
private Integer unknownDocRating;
|
||||
private final Integer unknownDocRating;
|
||||
|
||||
public static final String NAME = "dcg";
|
||||
private static final double LOG2 = Math.log(2.0);
|
||||
|
||||
public DiscountedCumulativeGain() {
|
||||
this(false, null);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -82,13 +85,6 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* If set to true, the dcg will be normalized (ndcg)
|
||||
*/
|
||||
public void setNormalize(boolean normalize) {
|
||||
this.normalize = normalize;
|
||||
}
|
||||
|
||||
/**
|
||||
* check whether this metric computes only dcg or "normalized" ndcg
|
||||
*/
|
||||
@ -96,13 +92,6 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
return this.normalize;
|
||||
}
|
||||
|
||||
/**
|
||||
* the rating for docs the user hasn't supplied an explicit rating for
|
||||
*/
|
||||
public void setUnknownDocRating(int unknownDocRating) {
|
||||
this.unknownDocRating = unknownDocRating;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the rating used for unrated documents
|
||||
*/
|
||||
@ -118,10 +107,10 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
|
||||
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
|
||||
for (RatedSearchHit hit : ratedHits) {
|
||||
// unknownDocRating might be null, which means it will be unrated
|
||||
// docs are ignored in the dcg calculation
|
||||
// we still need to add them as a placeholder so the rank of the
|
||||
// subsequent ratings is correct
|
||||
// unknownDocRating might be null, in which case unrated docs are
|
||||
// ignored in the dcg calculation
|
||||
// we still need to add them as a placeholder so the rank of the subsequent
|
||||
// ratings is correct
|
||||
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
|
||||
}
|
||||
double dcg = computeDCG(ratingsInSearchHits);
|
||||
@ -151,12 +140,15 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
|
||||
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
|
||||
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
|
||||
private static final ObjectParser<DiscountedCumulativeGain, Void> PARSER = new ObjectParser<>(
|
||||
"dcg_at", () -> new DiscountedCumulativeGain());
|
||||
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at",
|
||||
args -> {
|
||||
Boolean normalized = (Boolean) args[0];
|
||||
return new DiscountedCumulativeGain(normalized == null ? false : normalized, (Integer) args[1]);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareBoolean(DiscountedCumulativeGain::setNormalize, NORMALIZE_FIELD);
|
||||
PARSER.declareInt(DiscountedCumulativeGain::setUnknownDocRating, UNKNOWN_DOC_RATING_FIELD);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), NORMALIZE_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), UNKNOWN_DOC_RATING_FIELD);
|
||||
}
|
||||
|
||||
public static DiscountedCumulativeGain fromXContent(XContentParser parser) {
|
||||
@ -193,6 +185,4 @@ public class DiscountedCumulativeGain implements RankedListQualityMetric {
|
||||
public final int hashCode() {
|
||||
return Objects.hash(normalize, unknownDocRating);
|
||||
}
|
||||
|
||||
// TODO maybe also add debugging breakdown here
|
||||
}
|
||||
|
@ -1,106 +0,0 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
 * Identifies a document by the combination of its index name and its document id.
 * Instances can be written to a stream ({@link #writeTo(StreamOutput)}) and rendered
 * as an object ({@link #toXContent(XContentBuilder, Params)}).
 */
public class DocumentKey implements Writeable, ToXContentObject {
|
||||
|
||||
/** id of the document inside {@link #index} */
private String docId;
|
||||
/** name of the index the document lives in */
private String index;
|
||||
|
||||
/**
 * Replaces the index name. Unlike the public constructor, no null/empty check is applied here.
 * NOTE(review): both fields feed equals()/hashCode(), so changing them while the key is
 * stored in a hash-based collection would presumably break lookups - confirm callers.
 */
void setIndex(String index) {
|
||||
this.index = index;
|
||||
}
|
||||
|
||||
/**
 * Replaces the document id. Unlike the public constructor, no null/empty check is applied here.
 */
void setDocId(String docId) {
|
||||
this.docId = docId;
|
||||
}
|
||||
|
||||
/**
 * Creates a key from an index name and a document id.
 *
 * @param index the index name; rejected when {@code Strings.isNullOrEmpty(index)} is true
 * @param docId the document id; rejected when {@code Strings.isNullOrEmpty(docId)} is true
 * @throws IllegalArgumentException if either argument is rejected
 */
public DocumentKey(String index, String docId) {
|
||||
if (Strings.isNullOrEmpty(index)) {
|
||||
throw new IllegalArgumentException("Index must be set for each rated document");
|
||||
}
|
||||
if (Strings.isNullOrEmpty(docId)) {
|
||||
throw new IllegalArgumentException("DocId must be set for each rated document");
|
||||
}
|
||||
|
||||
this.index = index;
|
||||
this.docId = docId;
|
||||
}
|
||||
|
||||
/**
 * Reads a key from a stream: first the index name, then the document id.
 * This is the same order in which {@link #writeTo(StreamOutput)} writes them.
 * The values read are not validated here.
 */
public DocumentKey(StreamInput in) throws IOException {
|
||||
this.index = in.readString();
|
||||
this.docId = in.readString();
|
||||
}
|
||||
|
||||
/** Returns the index name of this key. */
public String getIndex() {
|
||||
return index;
|
||||
}
|
||||
|
||||
/** Returns the document id of this key. */
public String getDocID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
/**
 * Writes the index name followed by the document id, both as strings.
 */
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(index);
|
||||
out.writeString(docId);
|
||||
}
|
||||
|
||||
/**
 * Two keys are equal when they are of exactly the same runtime class and both
 * their index name and document id are equal.
 */
@Override
|
||||
public final boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
// exact class comparison rather than instanceof
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
DocumentKey other = (DocumentKey) obj;
|
||||
return Objects.equals(index, other.index) && Objects.equals(docId, other.docId);
|
||||
}
|
||||
|
||||
/** Hash over the same two fields that {@link #equals(Object)} compares. */
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(index, docId);
|
||||
}
|
||||
|
||||
/**
 * Renders this key as an object with two fields, named after
 * {@code RatedDocument.INDEX_FIELD} and {@code RatedDocument.DOC_ID_FIELD}.
 *
 * @return the builder that was passed in
 */
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), index);
|
||||
builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), docId);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
/**
 * Delegates to {@code Strings.toString(this)}.
 * NOTE(review): presumably this yields the XContent rendering produced by toXContent - confirm.
 */
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
@ -24,6 +24,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
@ -91,8 +92,11 @@ public class EvalQueryQuality implements ToXContent, Writeable {
|
||||
builder.startObject(id);
|
||||
builder.field("quality_level", this.qualityLevel);
|
||||
builder.startArray("unknown_docs");
|
||||
for (DocumentKey key : RankedListQualityMetric.filterUnknownDocuments(hits)) {
|
||||
key.toXContent(builder, params);
|
||||
for (DocumentKey key : EvaluationMetric.filterUnknownDocuments(hits)) {
|
||||
builder.startObject();
|
||||
builder.field(RatedDocument.INDEX_FIELD.getPreferredName(), key.getIndex());
|
||||
builder.field(RatedDocument.DOC_ID_FIELD.getPreferredName(), key.getDocId());
|
||||
builder.endObject();
|
||||
}
|
||||
builder.endArray();
|
||||
builder.startArray("hits");
|
||||
|
@ -24,7 +24,9 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser.Token;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
@ -35,13 +37,14 @@ import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Classes implementing this interface provide a means to compute the quality of a result list returned by some search.
|
||||
* Implementations of {@link EvaluationMetric} need to provide a way to compute the quality metric for
|
||||
* a result list returned by some search ({@link SearchHits}) and a list of rated documents.
|
||||
*/
|
||||
public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
|
||||
public interface EvaluationMetric extends ToXContent, NamedWriteable {
|
||||
|
||||
/**
|
||||
* Returns a single metric representing the ranking quality of a set of returned
|
||||
* documents wrt. to a set of document Ids labeled as relevant for this search.
|
||||
* documents with respect to a set of document ids labeled as relevant for this search.
|
||||
*
|
||||
* @param taskId
|
||||
* the id of the query for which the ranking is currently evaluated
|
||||
@ -55,15 +58,15 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
|
||||
*/
|
||||
EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs);
|
||||
|
||||
static RankedListQualityMetric fromXContent(XContentParser parser) throws IOException {
|
||||
RankedListQualityMetric rc;
|
||||
static EvaluationMetric fromXContent(XContentParser parser) throws IOException {
|
||||
EvaluationMetric rc;
|
||||
Token token = parser.nextToken();
|
||||
if (token != XContentParser.Token.FIELD_NAME) {
|
||||
throw new ParsingException(parser.getTokenLocation(), "[_na] missing required metric name");
|
||||
}
|
||||
String metricName = parser.currentName();
|
||||
|
||||
// TODO maybe switch to using a plugable registry later?
|
||||
// TODO switch to using a pluggable registry
|
||||
switch (metricName) {
|
||||
case PrecisionAtK.NAME:
|
||||
rc = PrecisionAtK.fromXContent(parser);
|
||||
@ -101,13 +104,19 @@ public interface RankedListQualityMetric extends ToXContent, NamedWriteable {
|
||||
return ratedSearchHits;
|
||||
}
|
||||
|
||||
/**
|
||||
* Collects the keys of all {@link RatedSearchHit}s that don't have a rating.
*
* @param ratedHits the search hits, each optionally carrying a rating
* @return one {@link DocumentKey} (index and id of the search hit) per hit without a rating
|
||||
*/
|
||||
static List<DocumentKey> filterUnknownDocuments(List<RatedSearchHit> ratedHits) {
|
||||
// keep only the hits without a rating and map each one to its index/id key
|
||||
List<DocumentKey> unknownDocs = ratedHits.stream().filter(hit -> hit.getRating().isPresent() == false)
|
||||
.map(hit -> new DocumentKey(hit.getSearchHit().getIndex(), hit.getSearchHit().getId())).collect(Collectors.toList());
|
||||
return unknownDocs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Combines the evaluation results of the individual search queries into the overall evaluation score.
|
||||
* Defaults to the arithmetic mean of the quality levels of the partial results.
* <p>
* Note: for an empty collection this evaluates to {@code 0.0 / 0}, i.e. {@code NaN}.
*
* @param partialResults the evaluation results of the individual search queries
* @return the sum of all quality levels divided by the number of partial results
|
||||
*/
|
||||
default double combine(Collection<EvalQueryQuality> partialResults) {
|
||||
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
|
||||
}
|
@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
@ -34,49 +34,45 @@ import java.util.Optional;
|
||||
|
||||
import javax.naming.directory.SearchResult;
|
||||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
||||
|
||||
/**
|
||||
* Evaluate mean reciprocal rank. By default documents with a rating equal or bigger
|
||||
* than 1 are considered to be "relevant" for the reciprocal rank calculation.
|
||||
* This value can be changes using the "relevant_rating_threshold" parameter.
|
||||
* Evaluates using mean reciprocal rank. By default documents with a rating
|
||||
* equal or bigger than 1 are considered to be "relevant" for the reciprocal
|
||||
* rank calculation. This value can be changed using the
|
||||
* "relevant_rating_threshold" parameter.
|
||||
*/
|
||||
public class MeanReciprocalRank implements RankedListQualityMetric {
|
||||
public class MeanReciprocalRank implements EvaluationMetric {
|
||||
|
||||
private static final int DEFAULT_RATING_THRESHOLD = 1;
|
||||
|
||||
public static final String NAME = "mean_reciprocal_rank";
|
||||
|
||||
/** ratings equal or above this value will be considered relevant. */
|
||||
private int relevantRatingThreshhold = 1;
|
||||
private final int relevantRatingThreshhold;
|
||||
|
||||
/**
|
||||
* Initializes the metric with the default relevant rating threshold
|
||||
*/
|
||||
public MeanReciprocalRank() {
|
||||
// use defaults
|
||||
this(DEFAULT_RATING_THRESHOLD);
|
||||
}
|
||||
|
||||
public MeanReciprocalRank(StreamInput in) throws IOException {
|
||||
this.relevantRatingThreshhold = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the rating threshold above which ratings are considered to be
|
||||
* "relevant" for this metric.
|
||||
*/
|
||||
public void setRelevantRatingThreshhold(int threshold) {
|
||||
public MeanReciprocalRank(int threshold) {
|
||||
if (threshold < 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the rating threshold above which ratings are considered to be
|
||||
* "relevant" for this metric. Defaults to 1.
|
||||
@ -119,13 +115,19 @@ public class MeanReciprocalRank implements RankedListQualityMetric {
|
||||
out.writeVInt(relevantRatingThreshhold);
|
||||
}
|
||||
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField(
|
||||
"relevant_rating_threshold");
|
||||
private static final ObjectParser<MeanReciprocalRank, Void> PARSER = new ObjectParser<>(
|
||||
"reciprocal_rank", () -> new MeanReciprocalRank());
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ConstructingObjectParser<MeanReciprocalRank, Void> PARSER = new ConstructingObjectParser<>("reciprocal_rank",
|
||||
args -> {
|
||||
Integer optionalThreshold = (Integer) args[0];
|
||||
if (optionalThreshold == null) {
|
||||
return new MeanReciprocalRank();
|
||||
} else {
|
||||
return new MeanReciprocalRank(optionalThreshold);
|
||||
}
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareInt(MeanReciprocalRank::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
||||
}
|
||||
|
||||
public static MeanReciprocalRank fromXContent(XContentParser parser) {
|
||||
|
@ -22,7 +22,7 @@ package org.elasticsearch.index.rankeval;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
@ -34,7 +34,8 @@ import java.util.Optional;
|
||||
|
||||
import javax.naming.directory.SearchResult;
|
||||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsWithRatings;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
|
||||
|
||||
/**
|
||||
* Evaluate Precision of the search results. Documents without a rating are
|
||||
@ -42,15 +43,12 @@ import static org.elasticsearch.index.rankeval.RankedListQualityMetric.joinHitsW
|
||||
* considered to be "relevant" for the precision calculation. This value can be
|
||||
* changed using the "relevant_rating_threshold" parameter.
|
||||
*/
|
||||
public class PrecisionAtK implements RankedListQualityMetric {
|
||||
public class PrecisionAtK implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "precision";
|
||||
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField(
|
||||
"relevant_rating_threshold");
|
||||
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
|
||||
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
|
||||
private static final ObjectParser<PrecisionAtK, Void> PARSER = new ObjectParser<>(NAME,
|
||||
PrecisionAtK::new);
|
||||
|
||||
/**
|
||||
* This setting controls how unlabeled documents in the search hits are
|
||||
@ -58,29 +56,47 @@ public class PrecisionAtK implements RankedListQualityMetric {
|
||||
* as true or false positives. Set to 'false', they are treated as false
|
||||
* positives.
|
||||
*/
|
||||
private boolean ignoreUnlabeled = false;
|
||||
private final boolean ignoreUnlabeled;
|
||||
|
||||
/** ratings equal or above this value will be considered relevant. */
|
||||
private int relevantRatingThreshhold = 1;
|
||||
private final int relevantRatingThreshhold;
|
||||
|
||||
/**
 * Creates a precision metric with an explicit relevance threshold and handling of unlabeled documents.
 *
 * @param threshold
 *            stored as the relevant rating threshold; must not be negative
 * @param ignoreUnlabeled
 *            stored as the setting that controls how unlabeled documents in the search hits are treated
 * @throws IllegalArgumentException
 *             if {@code threshold} is negative
 */
public PrecisionAtK(int threshold, boolean ignoreUnlabeled) {
|
||||
// only negative values are rejected: a threshold of 0 passes this check even though
// the message below asks for a "positive integer"
if (threshold < 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
this.ignoreUnlabeled = ignoreUnlabeled;
|
||||
}
|
||||
|
||||
public PrecisionAtK() {
|
||||
// needed for supplier in parser
|
||||
this(1, false);
|
||||
}
|
||||
|
||||
|
||||
private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
args -> {
|
||||
Integer threshHold = (Integer) args[0];
|
||||
Boolean ignoreUnlabeled = (Boolean) args[1];
|
||||
return new PrecisionAtK(threshHold == null ? 1 : threshHold,
|
||||
ignoreUnlabeled == null ? false : ignoreUnlabeled);
|
||||
});
|
||||
|
||||
static {
|
||||
PARSER.declareInt(PrecisionAtK::setRelevantRatingThreshhold, RELEVANT_RATING_FIELD);
|
||||
PARSER.declareBoolean(PrecisionAtK::setIgnoreUnlabeled, IGNORE_UNLABELED_FIELD);
|
||||
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
|
||||
}
|
||||
|
||||
public PrecisionAtK(StreamInput in) throws IOException {
|
||||
relevantRatingThreshhold = in.readOptionalVInt();
|
||||
ignoreUnlabeled = in.readOptionalBoolean();
|
||||
PrecisionAtK(StreamInput in) throws IOException {
|
||||
relevantRatingThreshhold = in.readVInt();
|
||||
ignoreUnlabeled = in.readBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalVInt(relevantRatingThreshhold);
|
||||
out.writeOptionalBoolean(ignoreUnlabeled);
|
||||
out.writeVInt(relevantRatingThreshhold);
|
||||
out.writeBoolean(ignoreUnlabeled);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -88,18 +104,6 @@ public class PrecisionAtK implements RankedListQualityMetric {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the rating threshold above which ratings are considered to be
|
||||
* "relevant" for this metric.
|
||||
*/
|
||||
public void setRelevantRatingThreshhold(int threshold) {
|
||||
if (threshold < 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the rating threshold above which ratings are considered to be
|
||||
* "relevant" for this metric. Defaults to 1.
|
||||
@ -108,13 +112,6 @@ public class PrecisionAtK implements RankedListQualityMetric {
|
||||
return relevantRatingThreshhold;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the 'ignore_unlabeled' parameter
|
||||
*/
|
||||
public void setIgnoreUnlabeled(boolean ignoreUnlabeled) {
|
||||
this.ignoreUnlabeled = ignoreUnlabeled;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the 'ignore_unlabeled' parameter
|
||||
*/
|
||||
|
@ -56,24 +56,19 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns parsers for {@link NamedWriteable} this plugin will use over the
|
||||
* transport protocol.
|
||||
*
|
||||
* Returns parsers for {@link NamedWriteable} objects that this plugin sends over the transport protocol.
|
||||
* @see NamedWriteableRegistry
|
||||
*/
|
||||
@Override
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class,
|
||||
PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class,
|
||||
MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class,
|
||||
DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME,
|
||||
PrecisionAtK.Breakdown::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class,
|
||||
MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetails.class, PrecisionAtK.NAME, PrecisionAtK.Breakdown::new));
|
||||
namedWriteables
|
||||
.add(new NamedWriteableRegistry.Entry(MetricDetails.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Breakdown::new));
|
||||
return namedWriteables;
|
||||
}
|
||||
}
|
||||
|
@ -34,7 +34,9 @@ import org.elasticsearch.script.Script;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
@ -53,9 +55,9 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
* Collection of query specifications, that is e.g. search request templates
|
||||
* to use for query translation.
|
||||
*/
|
||||
private Collection<RatedRequest> ratedRequests = new ArrayList<>();
|
||||
private final List<RatedRequest> ratedRequests;
|
||||
/** Definition of the quality metric, e.g. precision at N */
|
||||
private RankedListQualityMetric metric;
|
||||
private final EvaluationMetric metric;
|
||||
/** Maximum number of requests to execute in parallel. */
|
||||
private int maxConcurrentSearches = MAX_CONCURRENT_SEARCHES;
|
||||
/** Default max number of requests. */
|
||||
@ -63,7 +65,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
/** optional: Templates to base test requests on */
|
||||
private Map<String, Script> templates = new HashMap<>();
|
||||
|
||||
public RankEvalSpec(Collection<RatedRequest> ratedRequests, RankedListQualityMetric metric,
|
||||
public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric,
|
||||
Collection<ScriptWithId> templates) {
|
||||
if (ratedRequests == null || ratedRequests.size() < 1) {
|
||||
throw new IllegalStateException(
|
||||
@ -92,7 +94,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
}
|
||||
}
|
||||
|
||||
public RankEvalSpec(Collection<RatedRequest> ratedRequests, RankedListQualityMetric metric) {
|
||||
public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric) {
|
||||
this(ratedRequests, metric, null);
|
||||
}
|
||||
|
||||
@ -102,7 +104,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
for (int i = 0; i < specSize; i++) {
|
||||
ratedRequests.add(new RatedRequest(in));
|
||||
}
|
||||
metric = in.readNamedWriteable(RankedListQualityMetric.class);
|
||||
metric = in.readNamedWriteable(EvaluationMetric.class);
|
||||
int size = in.readVInt();
|
||||
for (int i = 0; i < size; i++) {
|
||||
String key = in.readString();
|
||||
@ -128,13 +130,13 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
}
|
||||
|
||||
/** Returns the metric to use for quality evaluation.*/
|
||||
public RankedListQualityMetric getMetric() {
|
||||
public EvaluationMetric getMetric() {
|
||||
return metric;
|
||||
}
|
||||
|
||||
/** Returns a list of intent to query translation specifications to evaluate. */
|
||||
public Collection<RatedRequest> getRatedRequests() {
|
||||
return ratedRequests;
|
||||
public List<RatedRequest> getRatedRequests() {
|
||||
return Collections.unmodifiableList(ratedRequests);
|
||||
}
|
||||
|
||||
/** Returns the template to base test requests on. */
|
||||
@ -160,8 +162,8 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<RankEvalSpec, Void> PARSER =
|
||||
new ConstructingObjectParser<>("rank_eval",
|
||||
a -> new RankEvalSpec((Collection<RatedRequest>) a[0],
|
||||
(RankedListQualityMetric) a[1], (Collection<ScriptWithId>) a[2]));
|
||||
a -> new RankEvalSpec((List<RatedRequest>) a[0],
|
||||
(EvaluationMetric) a[1], (Collection<ScriptWithId>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
@ -169,7 +171,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
|
||||
} , REQUESTS_FIELD);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
try {
|
||||
return RankedListQualityMetric.fromXContent(p);
|
||||
return EvaluationMetric.fromXContent(p);
|
||||
} catch (IOException ex) {
|
||||
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
|
||||
}
|
||||
|
@ -33,13 +33,24 @@ import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* A document ID and its rating for the query QA use case.
|
||||
* Represents a document (specified by its _index/_id) and its corresponding rating
|
||||
* with respect to a specific search query.
|
||||
* <p>
|
||||
* Json structure in a request:
|
||||
* <pre>
|
||||
* {
|
||||
* "_index": "my_index",
|
||||
* "_id": "doc1",
|
||||
* "rating": 0
|
||||
* }
|
||||
* </pre>
|
||||
*
|
||||
*/
|
||||
public class RatedDocument implements Writeable, ToXContentObject {
|
||||
|
||||
public static final ParseField RATING_FIELD = new ParseField("rating");
|
||||
public static final ParseField DOC_ID_FIELD = new ParseField("_id");
|
||||
public static final ParseField INDEX_FIELD = new ParseField("_index");
|
||||
static final ParseField RATING_FIELD = new ParseField("rating");
|
||||
static final ParseField DOC_ID_FIELD = new ParseField("_id");
|
||||
static final ParseField INDEX_FIELD = new ParseField("_index");
|
||||
|
||||
private static final ConstructingObjectParser<RatedDocument, Void> PARSER = new ConstructingObjectParser<>("rated_document",
|
||||
a -> new RatedDocument((String) a[0], (String) a[1], (Integer) a[2]));
|
||||
@ -50,23 +61,19 @@ public class RatedDocument implements Writeable, ToXContentObject {
|
||||
PARSER.declareInt(ConstructingObjectParser.constructorArg(), RATING_FIELD);
|
||||
}
|
||||
|
||||
private int rating;
|
||||
private DocumentKey key;
|
||||
private final int rating;
|
||||
private final DocumentKey key;
|
||||
|
||||
public RatedDocument(String index, String docId, int rating) {
|
||||
this(new DocumentKey(index, docId), rating);
|
||||
}
|
||||
|
||||
public RatedDocument(StreamInput in) throws IOException {
|
||||
this.key = new DocumentKey(in);
|
||||
this.rating = in.readVInt();
|
||||
}
|
||||
|
||||
public RatedDocument(DocumentKey ratedDocumentKey, int rating) {
|
||||
this.key = ratedDocumentKey;
|
||||
public RatedDocument(String index, String id, int rating) {
|
||||
this.key = new DocumentKey(index, id);
|
||||
this.rating = rating;
|
||||
}
|
||||
|
||||
RatedDocument(StreamInput in) throws IOException {
|
||||
this.key = new DocumentKey(in.readString(), in.readString());
|
||||
this.rating = in.readVInt();
|
||||
}
|
||||
|
||||
public DocumentKey getKey() {
|
||||
return this.key;
|
||||
}
|
||||
@ -76,7 +83,7 @@ public class RatedDocument implements Writeable, ToXContentObject {
|
||||
}
|
||||
|
||||
public String getDocID() {
|
||||
return key.getDocID();
|
||||
return key.getDocId();
|
||||
}
|
||||
|
||||
public int getRating() {
|
||||
@ -85,11 +92,12 @@ public class RatedDocument implements Writeable, ToXContentObject {
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
this.key.writeTo(out);
|
||||
out.writeString(key.getIndex());
|
||||
out.writeString(key.getDocId());
|
||||
out.writeVInt(rating);
|
||||
}
|
||||
|
||||
public static RatedDocument fromXContent(XContentParser parser) {
|
||||
static RatedDocument fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
@ -97,7 +105,7 @@ public class RatedDocument implements Writeable, ToXContentObject {
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(INDEX_FIELD.getPreferredName(), key.getIndex());
|
||||
builder.field(DOC_ID_FIELD.getPreferredName(), key.getDocID());
|
||||
builder.field(DOC_ID_FIELD.getPreferredName(), key.getDocId());
|
||||
builder.field(RATING_FIELD.getPreferredName(), rating);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
@ -124,4 +132,55 @@ public class RatedDocument implements Writeable, ToXContentObject {
|
||||
public final int hashCode() {
|
||||
return Objects.hash(key, rating);
|
||||
}
|
||||
|
||||
/**
|
||||
* A joint document key consisting of the document's index and id
|
||||
*/
|
||||
static class DocumentKey {
|
||||
|
||||
private final String docId;
|
||||
private final String index;
|
||||
|
||||
DocumentKey(String index, String docId) {
|
||||
if (Strings.isNullOrEmpty(index)) {
|
||||
throw new IllegalArgumentException("Index must be set for each rated document");
|
||||
}
|
||||
if (Strings.isNullOrEmpty(docId)) {
|
||||
throw new IllegalArgumentException("DocId must be set for each rated document");
|
||||
}
|
||||
|
||||
this.index = index;
|
||||
this.docId = docId;
|
||||
}
|
||||
|
||||
String getIndex() {
|
||||
return index;
|
||||
}
|
||||
|
||||
String getDocId() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
DocumentKey other = (DocumentKey) obj;
|
||||
return Objects.equals(index, other.index) && Objects.equals(docId, other.docId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(index, docId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "{\"_index\":\"" + index + "\",\"_id\":\"" + docId + "\"}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -157,11 +157,11 @@ public class TransportRankEvalAction
|
||||
private RatedRequest specification;
|
||||
private Map<String, EvalQueryQuality> requestDetails;
|
||||
private Map<String, Exception> errors;
|
||||
private RankedListQualityMetric metric;
|
||||
private EvaluationMetric metric;
|
||||
private AtomicInteger responseCounter;
|
||||
|
||||
public RankEvalActionListener(ActionListener<RankEvalResponse> listener,
|
||||
RankedListQualityMetric metric, RatedRequest specification,
|
||||
EvaluationMetric metric, RatedRequest specification,
|
||||
Map<String, EvalQueryQuality> details, Map<String, Exception> errors,
|
||||
AtomicInteger responseCounter) {
|
||||
this.listener = listener;
|
||||
|
@ -19,6 +19,7 @@
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
@ -36,20 +37,23 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments;
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
|
||||
public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
|
||||
/**
|
||||
* Assuming the docs are ranked in the following order:
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
|
||||
* log_2(rank + 1)
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* -------------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5 4 | 0 | 0.0 | 2.321928094887362 | 0.0 5 | 1 | 1.0
|
||||
* | 2.584962500721156 | 0.38685280723454163 6 | 2 | 3.0 | 2.807354922057604
|
||||
* | 1.0686215613240666
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0
|
||||
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5
|
||||
* 4 | 0 | 0.0 | 2.321928094887362 | 0.0
|
||||
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
||||
* 6 | 2 | 3.0 | 2.807354922057604 | 1.0686215613240666
|
||||
*
|
||||
* dcg = 13.84826362927298 (sum of last column)
|
||||
*/
|
||||
@ -69,17 +73,18 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
* Check with normalization: to get the maximal possible dcg, sort documents by
|
||||
* relevance in descending order
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
|
||||
* log_2(rank + 1)
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* ---------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
|
||||
* 3 | 2 | 3.0 | 2.0 | 1.5 4 | 2 | 3.0 | 2.321928094887362
|
||||
* | 1.2920296742201793 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163 6
|
||||
* | 0 | 0.0 | 2.807354922057604 | 0.0
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0
|
||||
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
|
||||
* 3 | 2 | 3.0 | 2.0 | 1.5
|
||||
* 4 | 2 | 3.0 | 2.321928094887362 | 1.2920296742201793
|
||||
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
||||
* 6 | 0 | 0.0 | 2.807354922057604 | 0.0
|
||||
*
|
||||
* idcg = 14.595390756454922 (sum of last column)
|
||||
*/
|
||||
dcg.setNormalize(true);
|
||||
dcg = new DiscountedCumulativeGain(true, null);
|
||||
assertEquals(13.84826362927298 / 14.595390756454922, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
}
|
||||
|
||||
@ -87,12 +92,14 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
* This tests metric when some documents in the search result don't have a
|
||||
* rating provided by the user.
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
|
||||
* log_2(rank + 1)
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* -------------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5 4 | n/a | n/a | n/a | n/a 5 | 1 | 1.0
|
||||
* | 2.584962500721156 | 0.38685280723454163 6 | n/a | n/a | n/a | n/a
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0
|
||||
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5
|
||||
* 4 | n/a | n/a | n/a | n/a
|
||||
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
||||
* 6 | n/a | n/a | n/a | n/a
|
||||
*
|
||||
* dcg = 12.779642067948913 (sum of last column)
|
||||
*/
|
||||
@ -118,16 +125,18 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
* Check with normalization: to get the maximal possible dcg, sort documents by
|
||||
* relevance in descending order
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
|
||||
* log_2(rank + 1)
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* ----------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
|
||||
* 3 | 2 | 3.0 | 2.0 | 1.5 4 | 1 | 1.0 | 2.321928094887362 | 0.43067655807339
|
||||
* 5 | n.a | n.a | n.a. | n.a. 6 | n.a | n.a | n.a | n.a
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0
|
||||
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
|
||||
* 3 | 2 | 3.0 | 2.0 | 1.5
|
||||
* 4 | 1 | 1.0 | 2.321928094887362 | 0.43067655807339
|
||||
* 5 | n.a | n.a | n.a. | n.a.
|
||||
* 6 | n.a | n.a | n.a | n.a
|
||||
*
|
||||
* idcg = 13.347184833073591 (sum of last column)
|
||||
*/
|
||||
dcg.setNormalize(true);
|
||||
dcg = new DiscountedCumulativeGain(true, null);
|
||||
assertEquals(12.779642067948913 / 13.347184833073591, dcg.evaluate("id", hits, rated).getQualityLevel(), 0.00001);
|
||||
}
|
||||
|
||||
@ -136,13 +145,15 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
* documents than search hits because we restrict DCG to be calculated at the
|
||||
* fourth position
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
|
||||
* log_2(rank + 1)
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* -------------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5 4 | n/a | n/a | n/a | n/a
|
||||
* ----------------------------------------------------------------- 5 | 1 | 1.0
|
||||
* | 2.584962500721156 | 0.38685280723454163 6 | n/a | n/a | n/a | n/a
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0
|
||||
* 2 | 2 | 3.0 | 1.5849625007211563 | 1.8927892607143721
|
||||
* 3 | 3 | 7.0 | 2.0 | 3.5
|
||||
* 4 | n/a | n/a | n/a | n/a
|
||||
* -----------------------------------------------------------------
|
||||
* 5 | 1 | 1.0 | 2.584962500721156 | 0.38685280723454163
|
||||
* 6 | n/a | n/a | n/a | n/a
|
||||
*
|
||||
* dcg = 12.392789260714371 (sum of last column until position 4)
|
||||
*/
|
||||
@ -171,22 +182,24 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
* Check with normalization: to get the maximal possible dcg, sort documents by
|
||||
* relevance in descending order
|
||||
*
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) /
|
||||
* log_2(rank + 1)
|
||||
* rank | rel_rank | 2^(rel_rank) - 1 | log_2(rank + 1) | (2^(rel_rank) - 1) / log_2(rank + 1)
|
||||
* ---------------------------------------------------------------------------------------
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
|
||||
* 3 | 2 | 3.0 | 2.0 | 1.5 4 | 1 | 1.0 | 2.321928094887362 | 0.43067655807339
|
||||
* 1 | 3 | 7.0 | 1.0 | 7.0
|
||||
* 2 | 3 | 7.0 | 1.5849625007211563 | 4.416508275000202
|
||||
* 3 | 2 | 3.0 | 2.0 | 1.5
|
||||
* 4 | 1 | 1.0 | 2.321928094887362 | 0.43067655807339
|
||||
* ---------------------------------------------------------------------------------------
|
||||
* 5 | n.a | n.a | n.a. | n.a. 6 | n.a | n.a | n.a | n.a
|
||||
* 5 | n.a | n.a | n.a. | n.a.
|
||||
* 6 | n.a | n.a | n.a | n.a
|
||||
*
|
||||
* idcg = 13.347184833073591 (sum of last column)
|
||||
*/
|
||||
dcg.setNormalize(true);
|
||||
dcg = new DiscountedCumulativeGain(true, null);
|
||||
assertEquals(12.392789260714371 / 13.347184833073591, dcg.evaluate("id", hits, ratedDocs).getQualityLevel(), 0.00001);
|
||||
}
|
||||
|
||||
public void testParseFromXContent() throws IOException {
|
||||
String xContent = " {\n" + " \"unknown_doc_rating\": 2,\n" + " \"normalize\": true\n" + "}";
|
||||
String xContent = " { \"unknown_doc_rating\": 2, \"normalize\": true }";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
DiscountedCumulativeGain dcgAt = DiscountedCumulativeGain.fromXContent(parser);
|
||||
assertEquals(2, dcgAt.getUnknownDocRating().intValue());
|
||||
@ -217,29 +230,25 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
DiscountedCumulativeGain original = createTestItem();
|
||||
DiscountedCumulativeGain deserialized = RankEvalTestHelper.copy(original, DiscountedCumulativeGain::new);
|
||||
DiscountedCumulativeGain deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
|
||||
DiscountedCumulativeGain::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
DiscountedCumulativeGain testItem = createTestItem();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, DiscountedCumulativeGain::new));
|
||||
checkEqualsAndHashCode(createTestItem(), original -> {
|
||||
return new DiscountedCumulativeGain(original.getNormalize(), original.getUnknownDocRating());
|
||||
}, DiscountedCumulativeGainTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
|
||||
boolean normalise = original.getNormalize();
|
||||
int unknownDocRating = original.getUnknownDocRating();
|
||||
DiscountedCumulativeGain gain = new DiscountedCumulativeGain();
|
||||
gain.setNormalize(normalise);
|
||||
gain.setUnknownDocRating(unknownDocRating);
|
||||
|
||||
List<Runnable> mutators = new ArrayList<>();
|
||||
mutators.add(() -> gain.setNormalize(!original.getNormalize()));
|
||||
mutators.add(() -> gain.setUnknownDocRating(randomValueOtherThan(unknownDocRating, () -> randomIntBetween(0, 10))));
|
||||
randomFrom(mutators).run();
|
||||
return gain;
|
||||
if (randomBoolean()) {
|
||||
return new DiscountedCumulativeGain(!original.getNormalize(), original.getUnknownDocRating());
|
||||
} else {
|
||||
return new DiscountedCumulativeGain(original.getNormalize(),
|
||||
randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,67 +0,0 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class DocumentKeyTests extends ESTestCase {
|
||||
|
||||
static DocumentKey createRandomRatedDocumentKey() {
|
||||
String index = randomAlphaOfLengthBetween(1, 10);
|
||||
String docId = randomAlphaOfLengthBetween(1, 10);
|
||||
return new DocumentKey(index, docId);
|
||||
}
|
||||
|
||||
public DocumentKey createTestItem() {
|
||||
return createRandomRatedDocumentKey();
|
||||
}
|
||||
|
||||
public DocumentKey mutateTestItem(DocumentKey original) {
|
||||
String index = original.getIndex();
|
||||
String docId = original.getDocID();
|
||||
switch (randomIntBetween(0, 1)) {
|
||||
case 0:
|
||||
index = index + "_";
|
||||
break;
|
||||
case 1:
|
||||
docId = docId + "_";
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("The test should only allow two parameters mutated");
|
||||
}
|
||||
return new DocumentKey(index, docId);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
DocumentKey testItem = createRandomRatedDocumentKey();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
new DocumentKey(testItem.getIndex(), testItem.getDocID()));
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
DocumentKey original = createTestItem();
|
||||
DocumentKey deserialized = RankEvalTestHelper.copy(original, DocumentKey::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
}
|
@ -20,22 +20,24 @@
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class EvalQueryQualityTests extends ESTestCase {
|
||||
|
||||
private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(
|
||||
new RankEvalPlugin().getNamedWriteables());
|
||||
private static NamedWriteableRegistry namedWritableRegistry = new NamedWriteableRegistry(new RankEvalPlugin().getNamedWriteables());
|
||||
|
||||
public static EvalQueryQuality randomEvalQueryQuality() {
|
||||
List<DocumentKey> unknownDocs = new ArrayList<>();
|
||||
int numberOfUnknownDocs = randomInt(5);
|
||||
for (int i = 0; i < numberOfUnknownDocs; i++) {
|
||||
unknownDocs.add(DocumentKeyTests.createRandomRatedDocumentKey());
|
||||
unknownDocs.add(new DocumentKey(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||
}
|
||||
int numberOfSearchHits = randomInt(5);
|
||||
List<RatedSearchHit> ratedHits = new ArrayList<>();
|
||||
@ -54,17 +56,18 @@ public class EvalQueryQualityTests extends ESTestCase {
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
EvalQueryQuality original = randomEvalQueryQuality();
|
||||
EvalQueryQuality deserialized = RankEvalTestHelper.copy(original, EvalQueryQuality::new,
|
||||
namedWritableRegistry);
|
||||
EvalQueryQuality deserialized = copy(original);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
private static EvalQueryQuality copy(EvalQueryQuality original) throws IOException {
|
||||
return ESTestCase.copyWriteable(original, namedWritableRegistry, EvalQueryQuality::new);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
EvalQueryQuality testItem = randomEvalQueryQuality();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, EvalQueryQuality::new, namedWritableRegistry));
|
||||
checkEqualsAndHashCode(randomEvalQueryQuality(), EvalQueryQualityTests::copy, EvalQueryQualityTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static EvalQueryQuality mutateTestItem(EvalQueryQuality original) {
|
||||
|
@ -19,14 +19,15 @@
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.index.Index;
|
||||
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchShardTarget;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
@ -38,21 +39,35 @@ import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
||||
public class ReciprocalRankTests extends ESTestCase {
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class MeanReciprocalRankTests extends ESTestCase {
|
||||
|
||||
public void testParseFromXContent() throws IOException {
|
||||
String xContent = "{ }";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||
assertEquals(1, mrr.getRelevantRatingThreshold());
|
||||
}
|
||||
|
||||
xContent = "{ \"relevant_rating_threshold\": 2 }";
|
||||
try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
|
||||
MeanReciprocalRank mrr = MeanReciprocalRank.fromXContent(parser);
|
||||
assertEquals(2, mrr.getRelevantRatingThreshold());
|
||||
}
|
||||
}
|
||||
|
||||
public void testMaxAcceptableRank() {
|
||||
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank();
|
||||
|
||||
int searchHits = randomIntBetween(1, 50);
|
||||
|
||||
SearchHit[] hits = createSearchHits(0, searchHits, "test");
|
||||
List<RatedDocument> ratedDocs = new ArrayList<>();
|
||||
int relevantAt = randomIntBetween(0, searchHits);
|
||||
for (int i = 0; i <= searchHits; i++) {
|
||||
if (i == relevantAt) {
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.RELEVANT.ordinal()));
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.RELEVANT.ordinal()));
|
||||
} else {
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.IRRELEVANT.ordinal()));
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.IRRELEVANT.ordinal()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,9 +91,9 @@ public class ReciprocalRankTests extends ESTestCase {
|
||||
int relevantAt = randomIntBetween(0, 9);
|
||||
for (int i = 0; i <= 20; i++) {
|
||||
if (i == relevantAt) {
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.RELEVANT.ordinal()));
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.RELEVANT.ordinal()));
|
||||
} else {
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), Rating.IRRELEVANT.ordinal()));
|
||||
ratedDocs.add(new RatedDocument("test", Integer.toString(i), TestRatingEnum.IRRELEVANT.ordinal()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,8 +116,7 @@ public class ReciprocalRankTests extends ESTestCase {
|
||||
rated.add(new RatedDocument("test", "4", 4));
|
||||
SearchHit[] hits = createSearchHits(0, 5, "test");
|
||||
|
||||
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank();
|
||||
reciprocalRank.setRelevantRatingThreshhold(2);
|
||||
MeanReciprocalRank reciprocalRank = new MeanReciprocalRank(2);
|
||||
EvalQueryQuality evaluation = reciprocalRank.evaluate("id", hits, rated);
|
||||
assertEquals((double) 1 / 3, evaluation.getQualityLevel(), 0.00001);
|
||||
assertEquals(3, ((MeanReciprocalRank.Breakdown) evaluation.getMetricDetails()).getFirstRelevantRank());
|
||||
@ -153,35 +167,31 @@ public class ReciprocalRankTests extends ESTestCase {
|
||||
}
|
||||
|
||||
private static MeanReciprocalRank createTestItem() {
|
||||
MeanReciprocalRank testItem = new MeanReciprocalRank();
|
||||
testItem.setRelevantRatingThreshhold(randomIntBetween(0, 20));
|
||||
return testItem;
|
||||
return new MeanReciprocalRank(randomIntBetween(0, 20));
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
MeanReciprocalRank original = createTestItem();
|
||||
|
||||
MeanReciprocalRank deserialized = RankEvalTestHelper.copy(original, MeanReciprocalRank::new);
|
||||
MeanReciprocalRank deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
|
||||
MeanReciprocalRank::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
MeanReciprocalRank testItem = createTestItem();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, MeanReciprocalRank::new));
|
||||
checkEqualsAndHashCode(createTestItem(), MeanReciprocalRankTests::copy, MeanReciprocalRankTests::mutate);
|
||||
}
|
||||
|
||||
private static MeanReciprocalRank mutateTestItem(MeanReciprocalRank testItem) {
|
||||
int relevantThreshold = testItem.getRelevantRatingThreshold();
|
||||
MeanReciprocalRank rank = new MeanReciprocalRank();
|
||||
rank.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10)));
|
||||
return rank;
|
||||
private static MeanReciprocalRank copy(MeanReciprocalRank testItem) {
|
||||
return new MeanReciprocalRank(testItem.getRelevantRatingThreshold());
|
||||
}
|
||||
|
||||
private static MeanReciprocalRank mutate(MeanReciprocalRank testItem) {
|
||||
return new MeanReciprocalRank(randomValueOtherThan(testItem.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)));
|
||||
}
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
MeanReciprocalRank prez = new MeanReciprocalRank();
|
||||
expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new MeanReciprocalRank(-1));
|
||||
}
|
||||
}
|
@ -19,6 +19,7 @@
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
@ -38,11 +39,13 @@ import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
|
||||
public class PrecisionTests extends ESTestCase {
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class PrecisionAtKTests extends ESTestCase {
|
||||
|
||||
public void testPrecisionAtFiveCalculation() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
|
||||
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals(1, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(1, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -51,11 +54,11 @@ public class PrecisionTests extends ESTestCase {
|
||||
|
||||
public void testPrecisionAtFiveIgnoreOneResult() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "2", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "3", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "4", Rating.IRRELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "2", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "3", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "4", TestRatingEnum.IRRELEVANT.ordinal()));
|
||||
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals((double) 4 / 5, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(4, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -69,13 +72,12 @@ public class PrecisionTests extends ESTestCase {
|
||||
*/
|
||||
public void testPrecisionAtFiveRelevanceThreshold() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test", "0", 0));
|
||||
rated.add(new RatedDocument("test", "1", 1));
|
||||
rated.add(new RatedDocument("test", "2", 2));
|
||||
rated.add(new RatedDocument("test", "3", 3));
|
||||
rated.add(new RatedDocument("test", "4", 4));
|
||||
PrecisionAtK precisionAtN = new PrecisionAtK();
|
||||
precisionAtN.setRelevantRatingThreshhold(2);
|
||||
rated.add(createRatedDoc("test", "0", 0));
|
||||
rated.add(createRatedDoc("test", "1", 1));
|
||||
rated.add(createRatedDoc("test", "2", 2));
|
||||
rated.add(createRatedDoc("test", "3", 3));
|
||||
rated.add(createRatedDoc("test", "4", 4));
|
||||
PrecisionAtK precisionAtN = new PrecisionAtK(2, false);
|
||||
EvalQueryQuality evaluated = precisionAtN.evaluate("id", toSearchHits(rated, "test"), rated);
|
||||
assertEquals((double) 3 / 5, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -84,11 +86,11 @@ public class PrecisionTests extends ESTestCase {
|
||||
|
||||
public void testPrecisionAtFiveCorrectIndex() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test_other", "0", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test_other", "1", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "2", Rating.IRRELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test_other", "0", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test_other", "1", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "2", TestRatingEnum.IRRELEVANT.ordinal()));
|
||||
// the following search hits contain only the last three documents
|
||||
EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated);
|
||||
assertEquals((double) 2 / 3, evaluated.getQualityLevel(), 0.00001);
|
||||
@ -98,8 +100,8 @@ public class PrecisionTests extends ESTestCase {
|
||||
|
||||
public void testIgnoreUnlabeled() {
|
||||
List<RatedDocument> rated = new ArrayList<>();
|
||||
rated.add(new RatedDocument("test", "0", Rating.RELEVANT.ordinal()));
|
||||
rated.add(new RatedDocument("test", "1", Rating.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "0", TestRatingEnum.RELEVANT.ordinal()));
|
||||
rated.add(createRatedDoc("test", "1", TestRatingEnum.RELEVANT.ordinal()));
|
||||
// add an unlabeled search hit
|
||||
SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test"), 3);
|
||||
searchHits[2] = new SearchHit(2, "2", new Text("testtype"), Collections.emptyMap());
|
||||
@ -111,8 +113,7 @@ public class PrecisionTests extends ESTestCase {
|
||||
assertEquals(3, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
PrecisionAtK prec = new PrecisionAtK();
|
||||
prec.setIgnoreUnlabeled(true);
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true);
|
||||
evaluated = prec.evaluate("id", searchHits, rated);
|
||||
assertEquals((double) 2 / 2, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(2, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -131,8 +132,7 @@ public class PrecisionTests extends ESTestCase {
|
||||
assertEquals(5, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRetrieved());
|
||||
|
||||
// also try with setting `ignore_unlabeled`
|
||||
PrecisionAtK prec = new PrecisionAtK();
|
||||
prec.setIgnoreUnlabeled(true);
|
||||
PrecisionAtK prec = new PrecisionAtK(1, true);
|
||||
evaluated = prec.evaluate("id", hits, Collections.emptyList());
|
||||
assertEquals(0.0d, evaluated.getQualityLevel(), 0.00001);
|
||||
assertEquals(0, ((PrecisionAtK.Breakdown) evaluated.getMetricDetails()).getRelevantRetrieved());
|
||||
@ -158,16 +158,11 @@ public class PrecisionTests extends ESTestCase {
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
PrecisionAtK prez = new PrecisionAtK();
|
||||
expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new PrecisionAtK(-1, false));
|
||||
}
|
||||
|
||||
public static PrecisionAtK createTestItem() {
|
||||
PrecisionAtK precision = new PrecisionAtK();
|
||||
if (randomBoolean()) {
|
||||
precision.setRelevantRatingThreshhold(randomIntBetween(0, 10));
|
||||
}
|
||||
precision.setIgnoreUnlabeled(randomBoolean());
|
||||
return precision;
|
||||
return new PrecisionAtK(randomIntBetween(0, 10), randomBoolean());
|
||||
}
|
||||
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
@ -186,29 +181,28 @@ public class PrecisionTests extends ESTestCase {
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
PrecisionAtK original = createTestItem();
|
||||
PrecisionAtK deserialized = RankEvalTestHelper.copy(original, PrecisionAtK::new);
|
||||
PrecisionAtK deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
|
||||
PrecisionAtK::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
PrecisionAtK testItem = createTestItem();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), RankEvalTestHelper.copy(testItem, PrecisionAtK::new));
|
||||
checkEqualsAndHashCode(createTestItem(), PrecisionAtKTests::copy, PrecisionAtKTests::mutate);
|
||||
}
|
||||
|
||||
private static PrecisionAtK mutateTestItem(PrecisionAtK original) {
|
||||
boolean ignoreUnlabeled = original.getIgnoreUnlabeled();
|
||||
int relevantThreshold = original.getRelevantRatingThreshold();
|
||||
PrecisionAtK precision = new PrecisionAtK();
|
||||
precision.setIgnoreUnlabeled(ignoreUnlabeled);
|
||||
precision.setRelevantRatingThreshhold(relevantThreshold);
|
||||
private static PrecisionAtK copy(PrecisionAtK original) {
|
||||
return new PrecisionAtK(original.getRelevantRatingThreshold(), original.getIgnoreUnlabeled());
|
||||
}
|
||||
|
||||
List<Runnable> mutators = new ArrayList<>();
|
||||
mutators.add(() -> precision.setIgnoreUnlabeled(!ignoreUnlabeled));
|
||||
mutators.add(() -> precision.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10))));
|
||||
randomFrom(mutators).run();
|
||||
return precision;
|
||||
private static PrecisionAtK mutate(PrecisionAtK original) {
|
||||
if (randomBoolean()) {
|
||||
return new PrecisionAtK(original.getRelevantRatingThreshold(), !original.getIgnoreUnlabeled());
|
||||
} else {
|
||||
return new PrecisionAtK(randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)),
|
||||
original.getIgnoreUnlabeled());
|
||||
}
|
||||
}
|
||||
|
||||
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index) {
|
||||
@ -220,7 +214,7 @@ public class PrecisionTests extends ESTestCase {
|
||||
return hits;
|
||||
}
|
||||
|
||||
public enum Rating {
|
||||
IRRELEVANT, RELEVANT;
|
||||
private static RatedDocument createRatedDoc(String index, String id, int rating) {
|
||||
return new RatedDocument(index, id, rating);
|
||||
}
|
||||
}
|
@ -22,7 +22,6 @@ package org.elasticsearch.index.rankeval;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.index.rankeval.PrecisionTests.Rating;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
@ -35,7 +34,7 @@ import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.index.rankeval.RankedListQualityMetric.filterUnknownDocuments;
|
||||
import static org.elasticsearch.index.rankeval.EvaluationMetric.filterUnknownDocuments;
|
||||
|
||||
public class RankEvalRequestIT extends ESIntegTestCase {
|
||||
@Override
|
||||
@ -82,8 +81,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||
|
||||
specifications.add(berlinRequest);
|
||||
|
||||
PrecisionAtK metric = new PrecisionAtK();
|
||||
metric.setIgnoreUnlabeled(true);
|
||||
PrecisionAtK metric = new PrecisionAtK(1, true);
|
||||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
|
||||
@ -106,7 +104,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||
if (id.equals("1") || id.equals("6")) {
|
||||
assertFalse(hit.getRating().isPresent());
|
||||
} else {
|
||||
assertEquals(Rating.RELEVANT.ordinal(), hit.getRating().get().intValue());
|
||||
assertEquals(TestRatingEnum.RELEVANT.ordinal(), hit.getRating().get().intValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -117,7 +115,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||
for (RatedSearchHit hit : hitsAndRatings) {
|
||||
String id = hit.getSearchHit().getId();
|
||||
if (id.equals("1")) {
|
||||
assertEquals(Rating.RELEVANT.ordinal(), hit.getRating().get().intValue());
|
||||
assertEquals(TestRatingEnum.RELEVANT.ordinal(), hit.getRating().get().intValue());
|
||||
} else {
|
||||
assertFalse(hit.getRating().isPresent());
|
||||
}
|
||||
@ -167,7 +165,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
||||
private static List<RatedDocument> createRelevant(String... docs) {
|
||||
List<RatedDocument> relevant = new ArrayList<>();
|
||||
for (String doc : docs) {
|
||||
relevant.add(new RatedDocument("test", doc, Rating.RELEVANT.ordinal()));
|
||||
relevant.add(new RatedDocument("test", doc, TestRatingEnum.RELEVANT.ordinal()));
|
||||
}
|
||||
return relevant;
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.rankeval.RatedDocument.DocumentKey;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -43,7 +44,7 @@ public class RankEvalResponseTests extends ESTestCase {
|
||||
int numberOfUnknownDocs = randomIntBetween(0, 5);
|
||||
List<DocumentKey> unknownDocs = new ArrayList<>(numberOfUnknownDocs);
|
||||
for (int d = 0; d < numberOfUnknownDocs; d++) {
|
||||
unknownDocs.add(DocumentKeyTests.createRandomRatedDocumentKey());
|
||||
unknownDocs.add(new DocumentKey(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||
}
|
||||
EvalQueryQuality evalQuality = new EvalQueryQuality(id,
|
||||
randomDoubleBetween(0.0, 1.0, true));
|
||||
@ -65,12 +66,9 @@ public class RankEvalResponseTests extends ESTestCase {
|
||||
try (StreamInput in = output.bytes().streamInput()) {
|
||||
RankEvalResponse deserializedResponse = new RankEvalResponse();
|
||||
deserializedResponse.readFrom(in);
|
||||
assertEquals(randomResponse.getQualityLevel(),
|
||||
deserializedResponse.getQualityLevel(), Double.MIN_VALUE);
|
||||
assertEquals(randomResponse.getPartialResults(),
|
||||
deserializedResponse.getPartialResults());
|
||||
assertEquals(randomResponse.getFailures().keySet(),
|
||||
deserializedResponse.getFailures().keySet());
|
||||
assertEquals(randomResponse.getQualityLevel(), deserializedResponse.getQualityLevel(), Double.MIN_VALUE);
|
||||
assertEquals(randomResponse.getPartialResults(), deserializedResponse.getPartialResults());
|
||||
assertEquals(randomResponse.getFailures().keySet(), deserializedResponse.getFailures().keySet());
|
||||
assertNotSame(randomResponse, deserializedResponse);
|
||||
assertEquals(-1, in.read());
|
||||
}
|
||||
|
@ -45,6 +45,8 @@ import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class RankEvalSpecTests extends ESTestCase {
|
||||
|
||||
private static <T> List<T> randomList(Supplier<T> randomSupplier) {
|
||||
@ -57,9 +59,9 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||
}
|
||||
|
||||
private static RankEvalSpec createTestItem() throws IOException {
|
||||
RankedListQualityMetric metric;
|
||||
EvaluationMetric metric;
|
||||
if (randomBoolean()) {
|
||||
metric = PrecisionTests.createTestItem();
|
||||
metric = PrecisionAtKTests.createTestItem();
|
||||
} else {
|
||||
metric = DiscountedCumulativeGainTests.createTestItem();
|
||||
}
|
||||
@ -111,41 +113,30 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
RankEvalSpec original = createTestItem();
|
||||
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME,
|
||||
DiscountedCumulativeGain::new));
|
||||
namedWriteables
|
||||
.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
|
||||
RankEvalSpec deserialized = RankEvalTestHelper.copy(original, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables));
|
||||
RankEvalSpec deserialized = copy(original);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
RankEvalSpec testItem = createTestItem();
|
||||
|
||||
private static RankEvalSpec copy(RankEvalSpec original) throws IOException {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, DiscountedCumulativeGain.NAME,
|
||||
DiscountedCumulativeGain::new));
|
||||
namedWriteables
|
||||
.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
|
||||
RankEvalSpec mutant = RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables));
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(mutant),
|
||||
RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
|
||||
return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(namedWriteables), RankEvalSpec::new);
|
||||
}
|
||||
|
||||
private static RankEvalSpec mutateTestItem(RankEvalSpec mutant) {
|
||||
Collection<RatedRequest> ratedRequests = mutant.getRatedRequests();
|
||||
RankedListQualityMetric metric = mutant.getMetric();
|
||||
Map<String, Script> templates = mutant.getTemplates();
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
checkEqualsAndHashCode(createTestItem(), RankEvalSpecTests::copy, RankEvalSpecTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static RankEvalSpec mutateTestItem(RankEvalSpec original) {
|
||||
List<RatedRequest> ratedRequests = new ArrayList<>(original.getRatedRequests());
|
||||
EvaluationMetric metric = original.getMetric();
|
||||
Map<String, Script> templates = original.getTemplates();
|
||||
|
||||
int mutate = randomIntBetween(0, 2);
|
||||
switch (mutate) {
|
||||
@ -177,7 +168,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||
}
|
||||
|
||||
public void testMissingRatedRequestsFailsParsing() {
|
||||
RankedListQualityMetric metric = new PrecisionAtK();
|
||||
EvaluationMetric metric = new PrecisionAtK();
|
||||
expectThrows(IllegalStateException.class, () -> new RankEvalSpec(new ArrayList<>(), metric));
|
||||
expectThrows(IllegalStateException.class, () -> new RankEvalSpec(null, metric));
|
||||
}
|
||||
@ -189,7 +180,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
||||
}
|
||||
|
||||
public void testMissingTemplateAndSearchRequestFailsParsing() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("key", "value");
|
||||
|
||||
|
@ -1,94 +0,0 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotSame;
|
||||
import static org.junit.Assert.assertThat;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
// TODO replace by infra from ESTestCase
|
||||
public class RankEvalTestHelper {
|
||||
|
||||
public static <T> void testHashCodeAndEquals(T testItem, T mutation, T secondCopy) {
|
||||
assertFalse("testItem is equal to null", testItem.equals(null));
|
||||
assertFalse("testItem is equal to incompatible type", testItem.equals(""));
|
||||
assertTrue("testItem is not equal to self", testItem.equals(testItem));
|
||||
assertThat("same testItem's hashcode returns different values if called multiple times",
|
||||
testItem.hashCode(), equalTo(testItem.hashCode()));
|
||||
|
||||
assertThat("different testItem should not be equal", mutation, not(equalTo(testItem)));
|
||||
|
||||
assertNotSame("testItem copy is not same as original", testItem, secondCopy);
|
||||
assertTrue("testItem is not equal to its copy", testItem.equals(secondCopy));
|
||||
assertTrue("equals is not symmetric", secondCopy.equals(testItem));
|
||||
assertThat("testItem copy's hashcode is different from original hashcode",
|
||||
secondCopy.hashCode(), equalTo(testItem.hashCode()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a deep copy of an object by running it through a BytesStreamOutput
|
||||
*
|
||||
* @param original
|
||||
* the original object
|
||||
* @param reader
|
||||
* a function able to create a new copy of this type
|
||||
* @return a new copy of the original object
|
||||
*/
|
||||
public static <T extends Writeable> T copy(T original, Writeable.Reader<T> reader)
|
||||
throws IOException {
|
||||
return copy(original, reader, new NamedWriteableRegistry(Collections.emptyList()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a deep copy of an object by running it through a BytesStreamOutput
|
||||
*
|
||||
* @param original
|
||||
* the original object
|
||||
* @param reader
|
||||
* a function able to create a new copy of this type
|
||||
* @param namedWriteableRegistry
|
||||
* must be non-empty if the object itself or nested object
|
||||
* implement {@link NamedWriteable}
|
||||
* @return a new copy of the original object
|
||||
*/
|
||||
public static <T extends Writeable> T copy(T original, Writeable.Reader<T> reader,
|
||||
NamedWriteableRegistry namedWriteableRegistry) throws IOException {
|
||||
try (BytesStreamOutput output = new BytesStreamOutput()) {
|
||||
original.writeTo(output);
|
||||
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(),
|
||||
namedWriteableRegistry)) {
|
||||
return reader.read(in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -19,6 +19,7 @@
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
@ -27,15 +28,14 @@ import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class RatedDocumentTests extends ESTestCase {
|
||||
|
||||
public static RatedDocument createRatedDocument() {
|
||||
String index = randomAlphaOfLength(10);
|
||||
String docId = randomAlphaOfLength(10);
|
||||
int rating = randomInt();
|
||||
|
||||
return new RatedDocument(index, docId, rating);
|
||||
return new RatedDocument(randomAlphaOfLength(10), randomAlphaOfLength(10), randomInt());
|
||||
}
|
||||
|
||||
public void testXContentParsing() throws IOException {
|
||||
@ -52,22 +52,17 @@ public class RatedDocumentTests extends ESTestCase {
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
RatedDocument original = createRatedDocument();
|
||||
RatedDocument deserialized = RankEvalTestHelper.copy(original, RatedDocument::new);
|
||||
RatedDocument deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()),
|
||||
RatedDocument::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
RatedDocument testItem = createRatedDocument();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem), RankEvalTestHelper.copy(testItem, RatedDocument::new));
|
||||
}
|
||||
|
||||
public void testInvalidParsing() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument(null, "abc", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("", "abc", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", "", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", null, 10));
|
||||
checkEqualsAndHashCode(createRatedDocument(), original -> {
|
||||
return new RatedDocument(original.getIndex(), original.getDocID(), original.getRating());
|
||||
}, RatedDocumentTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static RatedDocument mutateTestItem(RatedDocument original) {
|
||||
|
@ -47,6 +47,7 @@ import java.util.stream.Stream;
|
||||
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class RatedRequestsTests extends ESTestCase {
|
||||
|
||||
@ -140,32 +141,26 @@ public class RatedRequestsTests extends ESTestCase {
|
||||
for (int i = 0; i < size; i++) {
|
||||
indices.add(randomAlphaOfLengthBetween(0, 50));
|
||||
}
|
||||
|
||||
RatedRequest original = createTestItem(indices, randomBoolean());
|
||||
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
|
||||
RatedRequest deserialized = RankEvalTestHelper.copy(original, RatedRequest::new, new NamedWriteableRegistry(namedWriteables));
|
||||
RatedRequest deserialized = copy(original);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
private static RatedRequest copy(RatedRequest original) throws IOException {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(namedWriteables), RatedRequest::new);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
List<String> indices = new ArrayList<>();
|
||||
int size = randomIntBetween(0, 20);
|
||||
for (int i = 0; i < size; i++) {
|
||||
indices.add(randomAlphaOfLengthBetween(0, 50));
|
||||
}
|
||||
|
||||
RatedRequest testItem = createTestItem(indices, randomBoolean());
|
||||
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, RatedRequest::new, new NamedWriteableRegistry(namedWriteables)));
|
||||
checkEqualsAndHashCode(createTestItem(indices, randomBoolean()), RatedRequestsTests::copy, RatedRequestsTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static RatedRequest mutateTestItem(RatedRequest original) {
|
||||
@ -220,8 +215,7 @@ public class RatedRequestsTests extends ESTestCase {
|
||||
}
|
||||
|
||||
public void testDuplicateRatedDocThrowsException() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1),
|
||||
new RatedDocument(new DocumentKey("index1", "id1"), 5));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1), new RatedDocument("index1", "id1", 5));
|
||||
|
||||
// search request set, no summary fields
|
||||
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
|
||||
@ -237,45 +231,45 @@ public class RatedRequestsTests extends ESTestCase {
|
||||
}
|
||||
|
||||
public void testNullSummaryFieldsTreatment() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
RatedRequest request = new RatedRequest("id", ratedDocs, new SearchSourceBuilder());
|
||||
expectThrows(IllegalArgumentException.class, () -> request.setSummaryFields(null));
|
||||
}
|
||||
|
||||
public void testNullParamsTreatment() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
RatedRequest request = new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), null, null);
|
||||
assertNotNull(request.getParams());
|
||||
}
|
||||
|
||||
public void testSettingParamsAndRequestThrows() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("key", "value");
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), params, null));
|
||||
}
|
||||
|
||||
public void testSettingNeitherParamsNorRequestThrows() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, new HashMap<>(), "templateId"));
|
||||
}
|
||||
|
||||
public void testSettingParamsWithoutTemplateIdThrows() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("key", "value");
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, params, null));
|
||||
}
|
||||
|
||||
public void testSettingTemplateIdAndRequestThrows() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
expectThrows(IllegalArgumentException.class,
|
||||
() -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder(), null, "templateId"));
|
||||
}
|
||||
|
||||
public void testSettingTemplateIdNoParamsThrows() {
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument(new DocumentKey("index1", "id1"), 1));
|
||||
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, null, "templateId"));
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
@ -27,6 +28,8 @@ import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
|
||||
|
||||
public class RatedSearchHitTests extends ESTestCase {
|
||||
|
||||
public static RatedSearchHit randomRatedSearchHit() {
|
||||
@ -57,15 +60,18 @@ public class RatedSearchHitTests extends ESTestCase {
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
RatedSearchHit original = randomRatedSearchHit();
|
||||
RatedSearchHit deserialized = RankEvalTestHelper.copy(original, RatedSearchHit::new);
|
||||
RatedSearchHit deserialized = copy(original);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
RatedSearchHit testItem = randomRatedSearchHit();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, RatedSearchHit::new));
|
||||
checkEqualsAndHashCode(randomRatedSearchHit(), RatedSearchHitTests::copy, RatedSearchHitTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static RatedSearchHit copy(RatedSearchHit original) throws IOException {
|
||||
return ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), RatedSearchHit::new);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
enum TestRatingEnum {
|
||||
IRRELEVANT, RELEVANT;
|
||||
}
|
@ -35,4 +35,4 @@
|
||||
- match: { rank_eval.details.amsterdam_query.unknown_docs: [ ]}
|
||||
- match: { rank_eval.details.amsterdam_query.metric_details: {"relevant_docs_retrieved": 1, "docs_retrieved": 1}}
|
||||
|
||||
- is_true: rank_eval.failures.invalid_queryy
|
||||
- is_true: rank_eval.failures.invalid_query
|
||||
|
Loading…
x
Reference in New Issue
Block a user