Merge pull request #20140 from MaineC/feature/rank-eval-roundtrip-testing

Add roundtrip xcontent tests to rank eval implementation
This commit is contained in:
Isabel Drost-Fromm 2016-09-07 15:04:36 +02:00 committed by GitHub
commit de959ee6ed
18 changed files with 486 additions and 91 deletions

View File

@ -35,6 +35,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
@ -48,24 +49,6 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
public static final String NAME = "dcg_at_n";
private static final double LOG2 = Math.log(2.0);
public DiscountedCumulativeGainAt(StreamInput in) throws IOException {
position = in.readInt();
normalize = in.readBoolean();
unknownDocRating = in.readOptionalVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(position);
out.writeBoolean(normalize);
out.writeOptionalVInt(unknownDocRating);
}
@Override
public String getWriteableName() {
return NAME;
}
/**
* Initialises position with 10
* */
@ -83,6 +66,36 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
this.position = position;
}
/**
* @param position number of top results to check against a given set of relevant results. Must be positive.
* @param normalize If set to true, dcg will be normalized (ndcg)
* See https://en.wikipedia.org/wiki/Discounted_cumulative_gain
* @param unknownDocRating the rating for docs the user hasn't supplied an explicit rating for
* */
public DiscountedCumulativeGainAt(int position, boolean normalize, Integer unknownDocRating) {
this(position);
this.normalize = normalize;
this.unknownDocRating = unknownDocRating;
}
public DiscountedCumulativeGainAt(StreamInput in) throws IOException {
this(in.readInt());
normalize = in.readBoolean();
unknownDocRating = in.readOptionalVInt();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(position);
out.writeBoolean(normalize);
out.writeOptionalVInt(unknownDocRating);
}
@Override
public String getWriteableName() {
return NAME;
}
/**
* Return number of search results to check for quality metric.
*/
@ -184,6 +197,7 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(NAME);
builder.field(SIZE_FIELD.getPreferredName(), this.position);
builder.field(NORMALIZE_FIELD.getPreferredName(), this.normalize);
@ -191,6 +205,26 @@ public class DiscountedCumulativeGainAt extends RankedListQualityMetric {
builder.field(UNKNOWN_DOC_RATING_FIELD.getPreferredName(), this.unknownDocRating);
}
builder.endObject();
builder.endObject();
return builder;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
DiscountedCumulativeGainAt other = (DiscountedCumulativeGainAt) obj;
return Objects.equals(position, other.position) &&
Objects.equals(normalize, other.normalize) &&
Objects.equals(unknownDocRating, other.unknownDocRating);
}
@Override
public final int hashCode() {
return Objects.hash(position, normalize, unknownDocRating);
}
}

View File

@ -32,6 +32,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import javax.naming.directory.SearchResult;
@ -182,4 +183,20 @@ public class PrecisionAtN extends RankedListQualityMetric {
return builder;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PrecisionAtN other = (PrecisionAtN) obj;
return Objects.equals(n, other.n);
}
@Override
public final int hashCode() {
return Objects.hash(n);
}
}

View File

@ -32,6 +32,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Objects;
/**
* This class defines a ranking evaluation task including an id, a collection of queries to evaluate and the evaluation metric.
@ -43,9 +44,9 @@ import java.util.Collection;
public class RankEvalSpec extends ToXContentToBytes implements Writeable {
/** Collection of query specifications, that is e.g. search request templates to use for query translation. */
private Collection<QuerySpec> specifications = new ArrayList<>();
private Collection<RatedRequest> ratedRequests = new ArrayList<>();
/** Definition of the quality metric, e.g. precision at N */
private RankedListQualityMetric eval;
private RankedListQualityMetric metric;
/** a unique id for the whole QA task */
private String specId;
@ -53,34 +54,34 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
// TODO think if no args ctor is okay
}
public RankEvalSpec(String specId, Collection<QuerySpec> specs, RankedListQualityMetric metric) {
public RankEvalSpec(String specId, Collection<RatedRequest> specs, RankedListQualityMetric metric) {
this.specId = specId;
this.specifications = specs;
this.eval = metric;
this.ratedRequests = specs;
this.metric = metric;
}
public RankEvalSpec(StreamInput in) throws IOException {
int specSize = in.readInt();
specifications = new ArrayList<>(specSize);
ratedRequests = new ArrayList<>(specSize);
for (int i = 0; i < specSize; i++) {
specifications.add(new QuerySpec(in));
ratedRequests.add(new RatedRequest(in));
}
eval = in.readNamedWriteable(RankedListQualityMetric.class);
metric = in.readNamedWriteable(RankedListQualityMetric.class);
specId = in.readString();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(specifications.size());
for (QuerySpec spec : specifications) {
out.writeInt(ratedRequests.size());
for (RatedRequest spec : ratedRequests) {
spec.writeTo(out);
}
out.writeNamedWriteable(eval);
out.writeNamedWriteable(metric);
out.writeString(specId);
}
public void setEval(RankedListQualityMetric eval) {
this.eval = eval;
this.metric = eval;
}
public void setTaskId(String taskId) {
@ -93,22 +94,22 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
/** Returns the precision at n configuration (containing level of n to consider).*/
public RankedListQualityMetric getEvaluator() {
return eval;
return metric;
}
/** Sets the precision at n configuration (containing level of n to consider).*/
public void setEvaluator(RankedListQualityMetric config) {
this.eval = config;
this.metric = config;
}
/** Returns a list of intent to query translation specifications to evaluate. */
public Collection<QuerySpec> getSpecifications() {
return specifications;
public Collection<RatedRequest> getSpecifications() {
return ratedRequests;
}
/** Set the list of intent to query translation specifications to evaluate. */
public void setSpecifications(Collection<QuerySpec> specifications) {
this.specifications = specifications;
public void setSpecifications(Collection<RatedRequest> specifications) {
this.ratedRequests = specifications;
}
private static final ParseField SPECID_FIELD = new ParseField("spec_id");
@ -127,7 +128,7 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
} , METRIC_FIELD);
PARSER.declareObjectArray(RankEvalSpec::setSpecifications, (p, c) -> {
try {
return QuerySpec.fromXContent(p, c);
return RatedRequest.fromXContent(p, c);
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing rank request", ex);
}
@ -139,11 +140,11 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
builder.startObject();
builder.field(SPECID_FIELD.getPreferredName(), this.specId);
builder.startArray(REQUESTS_FIELD.getPreferredName());
for (QuerySpec spec : this.specifications) {
for (RatedRequest spec : this.ratedRequests) {
spec.toXContent(builder, params);
}
builder.endArray();
builder.field(METRIC_FIELD.getPreferredName(), this.eval);
builder.field(METRIC_FIELD.getPreferredName(), this.metric);
builder.endObject();
return builder;
}
@ -152,4 +153,22 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable {
return PARSER.parse(parser, context);
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RankEvalSpec other = (RankEvalSpec) obj;
return Objects.equals(specId, other.specId) &&
Objects.equals(ratedRequests, other.ratedRequests) &&
Objects.equals(metric, other.metric);
}
@Override
public final int hashCode() {
return Objects.hash(specId, ratedRequests, metric);
}
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -40,7 +41,8 @@ public class RatedDocument extends ToXContentToBytes implements Writeable {
public static final ParseField RATING_FIELD = new ParseField("rating");
public static final ParseField KEY_FIELD = new ParseField("key");
private static final ConstructingObjectParser<RatedDocument, RankEvalContext> PARSER = new ConstructingObjectParser<>("rated_document",
private static final ConstructingObjectParser<RatedDocument, ParseFieldMatcherSupplier> PARSER =
new ConstructingObjectParser<>("rated_document",
a -> new RatedDocument((RatedDocumentKey) a[0], (Integer) a[1]));
static {
@ -93,8 +95,8 @@ public class RatedDocument extends ToXContentToBytes implements Writeable {
out.writeVInt(rating);
}
public static RatedDocument fromXContent(XContentParser parser, RankEvalContext context) throws IOException {
return PARSER.apply(parser, context);
public static RatedDocument fromXContent(XContentParser parser, ParseFieldMatcherSupplier supplier) throws IOException {
return PARSER.apply(parser, supplier);
}
@Override
@ -121,6 +123,6 @@ public class RatedDocument extends ToXContentToBytes implements Writeable {
@Override
public final int hashCode() {
return Objects.hash(getClass(), key, rating);
return Objects.hash(key, rating);
}
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -36,7 +37,8 @@ public class RatedDocumentKey extends ToXContentToBytes implements Writeable {
public static final ParseField TYPE_FIELD = new ParseField("type");
public static final ParseField INDEX_FIELD = new ParseField("index");
private static final ConstructingObjectParser<RatedDocumentKey, RankEvalContext> PARSER = new ConstructingObjectParser<>("ratings",
private static final ConstructingObjectParser<RatedDocumentKey, ParseFieldMatcherSupplier> PARSER =
new ConstructingObjectParser<>("ratings",
a -> new RatedDocumentKey((String) a[0], (String) a[1], (String) a[2]));
static {
@ -102,8 +104,9 @@ public class RatedDocumentKey extends ToXContentToBytes implements Writeable {
out.writeString(type);
out.writeString(docId);
}
public static RatedDocumentKey fromXContent(XContentParser parser, RankEvalContext context) throws IOException {
public static RatedDocumentKey fromXContent(
XContentParser parser, ParseFieldMatcherSupplier context) throws IOException {
return PARSER.apply(parser, context);
}
@ -123,6 +126,6 @@ public class RatedDocumentKey extends ToXContentToBytes implements Writeable {
@Override
public final int hashCode() {
return Objects.hash(getClass(), index, type, docId);
return Objects.hash(index, type, docId);
}
}

View File

@ -33,6 +33,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* Defines a QA specification: All end user supplied query intents will be mapped to the search request specified in this search request
@ -40,7 +41,7 @@ import java.util.List;
*
* The resulting document lists can then be compared against what was specified in the set of rated documents as part of a QAQuery.
* */
public class QuerySpec extends ToXContentToBytes implements Writeable {
public class RatedRequest extends ToXContentToBytes implements Writeable {
private String specId;
private SearchSourceBuilder testRequest;
@ -49,12 +50,12 @@ public class QuerySpec extends ToXContentToBytes implements Writeable {
/** Collection of rated queries for this query QA specification.*/
private List<RatedDocument> ratedDocs = new ArrayList<>();
public QuerySpec() {
public RatedRequest() {
// ctor that doesn't require all args to be present immediatly is easier to use with ObjectParser
// TODO decide if we can require only id as mandatory, set default values for the rest?
}
public QuerySpec(String specId, SearchSourceBuilder testRequest, List<String> indices, List<String> types,
public RatedRequest(String specId, SearchSourceBuilder testRequest, List<String> indices, List<String> types,
List<RatedDocument> ratedDocs) {
this.specId = specId;
this.testRequest = testRequest;
@ -63,7 +64,7 @@ public class QuerySpec extends ToXContentToBytes implements Writeable {
this.ratedDocs = ratedDocs;
}
public QuerySpec(StreamInput in) throws IOException {
public RatedRequest(StreamInput in) throws IOException {
this.specId = in.readString();
testRequest = new SearchSourceBuilder(in);
int indicesSize = in.readInt();
@ -148,18 +149,18 @@ public class QuerySpec extends ToXContentToBytes implements Writeable {
private static final ParseField ID_FIELD = new ParseField("id");
private static final ParseField REQUEST_FIELD = new ParseField("request");
private static final ParseField RATINGS_FIELD = new ParseField("ratings");
private static final ObjectParser<QuerySpec, RankEvalContext> PARSER = new ObjectParser<>("requests", QuerySpec::new);
private static final ObjectParser<RatedRequest, RankEvalContext> PARSER = new ObjectParser<>("requests", RatedRequest::new);
static {
PARSER.declareString(QuerySpec::setSpecId, ID_FIELD);
PARSER.declareObject(QuerySpec::setTestRequest, (p, c) -> {
PARSER.declareString(RatedRequest::setSpecId, ID_FIELD);
PARSER.declareObject(RatedRequest::setTestRequest, (p, c) -> {
try {
return SearchSourceBuilder.fromXContent(c.getParseContext(), c.getAggs(), c.getSuggesters());
} catch (IOException ex) {
throw new ParsingException(p.getTokenLocation(), "error parsing request", ex);
}
} , REQUEST_FIELD);
PARSER.declareObjectArray(QuerySpec::setRatedDocs, (p, c) -> {
PARSER.declareObjectArray(RatedRequest::setRatedDocs, (p, c) -> {
try {
return RatedDocument.fromXContent(p, c);
} catch (IOException ex) {
@ -169,7 +170,7 @@ public class QuerySpec extends ToXContentToBytes implements Writeable {
}
/**
* Parses {@link QuerySpec} from rest representation:
* Parses {@link RatedRequest} from rest representation:
*
* Example:
* {
@ -188,7 +189,7 @@ public class QuerySpec extends ToXContentToBytes implements Writeable {
* "ratings": [{ "1": 1 }, { "2": 0 }, { "3": 1 } ]
* }
*/
public static QuerySpec fromXContent(XContentParser parser, RankEvalContext context) throws IOException {
public static RatedRequest fromXContent(XContentParser parser, RankEvalContext context) throws IOException {
return PARSER.parse(parser, context);
}
@ -205,4 +206,25 @@ public class QuerySpec extends ToXContentToBytes implements Writeable {
builder.endObject();
return builder;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RatedRequest other = (RatedRequest) obj;
return Objects.equals(specId, other.specId) &&
Objects.equals(testRequest, other.testRequest) &&
Objects.equals(indices, other.indices) &&
Objects.equals(types, other.types) &&
Objects.equals(ratedDocs, other.ratedDocs);
}
@Override
public final int hashCode() {
return Objects.hash(specId, testRequest, indices.hashCode(), types.hashCode(), ratedDocs.hashCode());
}
}

View File

@ -33,6 +33,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.naming.directory.SearchResult;
@ -171,4 +172,21 @@ public class ReciprocalRank extends RankedListQualityMetric {
builder.endObject();
return builder;
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
ReciprocalRank other = (ReciprocalRank) obj;
return Objects.equals(maxAcceptableRank, other.maxAcceptableRank);
}
@Override
public final int hashCode() {
return Objects.hash(maxAcceptableRank);
}
}

View File

@ -196,7 +196,7 @@ public class RestRankEvalAction extends BaseRestHandler {
List<String> indices = Arrays.asList(Strings.splitStringByCommaToArray(request.param("index")));
List<String> types = Arrays.asList(Strings.splitStringByCommaToArray(request.param("type")));
RankEvalSpec spec = RankEvalSpec.parse(context.parser(), context);
for (QuerySpec specification : spec.getSpecifications()) {
for (RatedRequest specification : spec.getSpecifications()) {
specification.setIndices(indices);
specification.setTypes(types);
};

View File

@ -64,11 +64,11 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
RankEvalSpec qualityTask = request.getRankEvalSpec();
Map<String, Collection<RatedDocumentKey>> unknownDocs = new ConcurrentHashMap<>();
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
Collection<RatedRequest> specifications = qualityTask.getSpecifications();
AtomicInteger responseCounter = new AtomicInteger(specifications.size());
Map<String, EvalQueryQuality> partialResults = new ConcurrentHashMap<>(specifications.size());
for (QuerySpec querySpecification : specifications) {
for (RatedRequest querySpecification : specifications) {
final RankEvalActionListener searchListener = new RankEvalActionListener(listener, qualityTask, querySpecification,
partialResults, unknownDocs, responseCounter);
SearchSourceBuilder specRequest = querySpecification.getTestRequest();
@ -85,13 +85,13 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
public static class RankEvalActionListener implements ActionListener<SearchResponse> {
private ActionListener<RankEvalResponse> listener;
private QuerySpec specification;
private RatedRequest specification;
private Map<String, EvalQueryQuality> partialResults;
private RankEvalSpec task;
private Map<String, Collection<RatedDocumentKey>> unknownDocs;
private AtomicInteger responseCounter;
public RankEvalActionListener(ActionListener<RankEvalResponse> listener, RankEvalSpec task, QuerySpec specification,
public RankEvalActionListener(ActionListener<RankEvalResponse> listener, RankEvalSpec task, RatedRequest specification,
Map<String, EvalQueryQuality> partialResults, Map<String, Collection<RatedDocumentKey>> unknownDocs,
AtomicInteger responseCounter) {
this.listener = listener;

View File

@ -121,4 +121,22 @@ public class DiscountedCumulativeGainAtTests extends ESTestCase {
assertEquals(8, dcgAt.getPosition());
assertEquals(true, dcgAt.getNormalize());
}
public static DiscountedCumulativeGainAt createTestItem() {
int position = randomIntBetween(0, 1000);
boolean normalize = randomBoolean();
Integer unknownDocRating = new Integer(randomIntBetween(0, 1000));
return new DiscountedCumulativeGainAt(position, normalize, unknownDocRating);
}
public void testXContentRoundtrip() throws IOException {
DiscountedCumulativeGainAt testItem = createTestItem();
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
itemParser.nextToken();
itemParser.nextToken();
DiscountedCumulativeGainAt parsedItem = DiscountedCumulativeGainAt.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}

View File

@ -134,4 +134,20 @@ public class PrecisionAtNTests extends ESTestCase {
partialResults.add(new EvalQueryQuality(0.6, emptyList()));
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
}
public static PrecisionAtN createTestItem() {
int position = randomIntBetween(0, 1000);
return new PrecisionAtN(position);
}
public void testXContentRoundtrip() throws IOException {
PrecisionAtN testItem = createTestItem();
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
itemParser.nextToken();
itemParser.nextToken();
PrecisionAtN parsedItem = PrecisionAtN.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}

View File

@ -35,7 +35,6 @@ import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class RankEvalRequestTests extends ESIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> transportClientPlugins() {
@ -72,11 +71,11 @@ public class RankEvalRequestTests extends ESIntegTestCase {
List<String> types = Arrays.asList(new String[] { "testtype" });
String specId = randomAsciiOfLength(10);
List<QuerySpec> specifications = new ArrayList<>();
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder testQuery = new SearchSourceBuilder();
testQuery.query(new MatchAllQueryBuilder());
specifications.add(new QuerySpec("amsterdam_query", testQuery, indices, types, createRelevant("2", "3", "4", "5")));
specifications.add(new QuerySpec("berlin_query", testQuery, indices, types, createRelevant("1")));
specifications.add(new RatedRequest("amsterdam_query", testQuery, indices, types, createRelevant("2", "3", "4", "5")));
specifications.add(new RatedRequest("berlin_query", testQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10));
@ -106,14 +105,14 @@ public class RankEvalRequestTests extends ESIntegTestCase {
List<String> types = Arrays.asList(new String[] { "testtype" });
String specId = randomAsciiOfLength(10);
List<QuerySpec> specifications = new ArrayList<>();
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder amsterdamQuery = new SearchSourceBuilder();
amsterdamQuery.query(new MatchAllQueryBuilder());
specifications.add(new QuerySpec("amsterdam_query", amsterdamQuery, indices, types, createRelevant("2", "3", "4", "5")));
specifications.add(new RatedRequest("amsterdam_query", amsterdamQuery, indices, types, createRelevant("2", "3", "4", "5")));
SearchSourceBuilder brokenQuery = new SearchSourceBuilder();
RangeQueryBuilder brokenRangeQuery = new RangeQueryBuilder("text").timeZone("CET");
brokenQuery.query(brokenRangeQuery);
specifications.add(new QuerySpec("broken_query", brokenQuery, indices, types, createRelevant("1")));
specifications.add(new RatedRequest("broken_query", brokenQuery, indices, types, createRelevant("1")));
RankEvalSpec task = new RankEvalSpec(specId, specifications, new PrecisionAtN(10));

View File

@ -0,0 +1,107 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers;
import org.elasticsearch.search.suggest.Suggesters;
import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static java.util.Collections.emptyList;
public class RankEvalSpecTests extends ESTestCase {
private static SearchModule searchModule;
private static SearchRequestParsers searchRequestParsers;
/**
* setup for the whole base test class
*/
@BeforeClass
public static void init() throws IOException {
AggregatorParsers aggsParsers = new AggregatorParsers(new ParseFieldRegistry<>("aggregation"),
new ParseFieldRegistry<>("aggregation_pipes"));
searchModule = new SearchModule(Settings.EMPTY, false, emptyList());
IndicesQueriesRegistry queriesRegistry = searchModule.getQueryParserRegistry();
Suggesters suggesters = searchModule.getSuggesters();
searchRequestParsers = new SearchRequestParsers(queriesRegistry, aggsParsers, suggesters);
}
@AfterClass
public static void afterClass() throws Exception {
searchModule = null;
searchRequestParsers = null;
}
public void testRoundtripping() throws IOException {
List<String> indices = new ArrayList<>();
int size = randomIntBetween(0, 20);
for (int i = 0; i < size; i++) {
indices.add(randomAsciiOfLengthBetween(0, 50));
}
List<String> types = new ArrayList<>();
size = randomIntBetween(0, 20);
for (int i = 0; i < size; i++) {
types.add(randomAsciiOfLengthBetween(0, 50));
}
List<RatedRequest> specs = new ArrayList<>();
size = randomIntBetween(1, 2); // TODO I guess requests with no query spec should be rejected...
for (int i = 0; i < size; i++) {
specs.add(RatedRequestsTests.createTestItem(indices, types));
}
String specId = randomAsciiOfLengthBetween(1, 10); // TODO we should reject zero length ids ...
RankedListQualityMetric metric;
if (randomBoolean()) {
metric = PrecisionAtNTests.createTestItem();
} else {
metric = DiscountedCumulativeGainAtTests.createTestItem();
}
RankEvalSpec testItem = new RankEvalSpec(specId, specs, metric);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers);
RankEvalSpec parsedItem = RankEvalSpec.parse(itemParser, rankContext);
// IRL these come from URL parameters - see RestRankEvalAction
parsedItem.getSpecifications().stream().forEach(e -> {e.setIndices(indices); e.setTypes(types);});
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}

View File

@ -0,0 +1,42 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
public class RatedDocumentKeyTests extends ESTestCase {
public void testXContentRoundtrip() throws IOException {
String index = randomAsciiOfLengthBetween(0, 10);
String type = randomAsciiOfLengthBetween(0, 10);
String docId = randomAsciiOfLengthBetween(0, 10);
RatedDocumentKey testItem = new RatedDocumentKey(index, type, docId);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
RatedDocumentKey parsedItem = RatedDocumentKey.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}

View File

@ -20,40 +20,28 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
public class RatedDocumentTests extends ESTestCase {
public void testXContentParsing() throws IOException {
public static RatedDocument createTestItem() {
String index = randomAsciiOfLength(10);
String type = randomAsciiOfLength(10);
String docId = randomAsciiOfLength(10);
int rating = randomInt();
RatedDocument testItem = new RatedDocument(new RatedDocumentKey(index, type, docId), rating);
XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
if (randomBoolean()) {
builder.prettyPrint();
}
testItem.toXContent(builder, ToXContent.EMPTY_PARAMS);
XContentBuilder shuffled = shuffleXContent(builder);
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes());
itemParser.nextToken();
return new RatedDocument(new RatedDocumentKey(index, type, docId), rating);
}
RankEvalContext context = new RankEvalContext(ParseFieldMatcher.STRICT, null, null);
RatedDocument parsedItem = RatedDocument.fromXContent(itemParser, context);
public void testXContentParsing() throws IOException {
RatedDocument testItem = createTestItem();
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
RatedDocument parsedItem = RatedDocument.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}

View File

@ -24,22 +24,25 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.search.aggregations.AggregatorParsers;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.suggest.Suggesters;
import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static java.util.Collections.emptyList;
public class QuerySpecTests extends ESTestCase {
public class RatedRequestsTests extends ESTestCase {
private static SearchModule searchModule;
private static SearchRequestParsers searchRequestParsers;
@ -63,7 +66,51 @@ public class QuerySpecTests extends ESTestCase {
searchRequestParsers = null;
}
// TODO add some sort of roundtrip testing like we have now for queries?
public static RatedRequest createTestItem(List<String> indices, List<String> types) {
String specId = randomAsciiOfLength(50);
SearchSourceBuilder testRequest = new SearchSourceBuilder();
testRequest.size(randomInt());
testRequest.query(new MatchAllQueryBuilder());
List<RatedDocument> ratedDocs = new ArrayList<>();
int size = randomIntBetween(0, 2);
for (int i = 0; i < size; i++) {
ratedDocs.add(RatedDocumentTests.createTestItem());
}
return new RatedRequest(specId, testRequest, indices, types, ratedDocs);
}
public void testXContentRoundtrip() throws IOException {
List<String> indices = new ArrayList<>();
int size = randomIntBetween(0, 20);
for (int i = 0; i < size; i++) {
indices.add(randomAsciiOfLengthBetween(0, 50));
}
List<String> types = new ArrayList<>();
size = randomIntBetween(0, 20);
for (int i = 0; i < size; i++) {
types.add(randomAsciiOfLengthBetween(0, 50));
}
RatedRequest testItem = createTestItem(indices, types);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
itemParser.nextToken();
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, itemParser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers);
RatedRequest parsedItem = RatedRequest.fromXContent(itemParser, rankContext);
parsedItem.setIndices(indices); // IRL these come from URL parameters - see RestRankEvalAction
parsedItem.setTypes(types); // IRL these come from URL parameters - see RestRankEvalAction
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
public void testParseFromXContent() throws IOException {
String querySpecString = " {\n"
+ " \"id\": \"my_qa_query\",\n"
@ -87,7 +134,7 @@ public class QuerySpecTests extends ESTestCase {
QueryParseContext queryContext = new QueryParseContext(searchRequestParsers.queryParsers, parser, ParseFieldMatcher.STRICT);
RankEvalContext rankContext = new RankEvalContext(ParseFieldMatcher.STRICT, queryContext,
searchRequestParsers);
QuerySpec specification = QuerySpec.fromXContent(parser, rankContext);
RatedRequest specification = RatedRequest.fromXContent(parser, rankContext);
assertEquals("my_qa_query", specification.getSpecId());
assertNotNull(specification.getTestRequest());
List<RatedDocument> ratedDocs = specification.getRatedDocs();
@ -99,4 +146,6 @@ public class QuerySpecTests extends ESTestCase {
assertEquals("3", ratedDocs.get(2).getKey().getDocID());
assertEquals(1, ratedDocs.get(2).getRating());
}
}

View File

@ -19,7 +19,9 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.rankeval.PrecisionAtN.Rating;
import org.elasticsearch.search.SearchShardTarget;
@ -148,4 +150,18 @@ public class ReciprocalRankTests extends ESTestCase {
EvalQueryQuality evaluation = reciprocalRank.evaluate(hits, ratedDocs);
assertEquals(0.0, evaluation.getQualityLevel(), Double.MIN_VALUE);
}
public void testXContentRoundtrip() throws IOException {
int position = randomIntBetween(0, 1000);
ReciprocalRank testItem = new ReciprocalRank(position);
XContentParser itemParser = XContentTestHelper.roundtrip(testItem);
itemParser.nextToken();
itemParser.nextToken();
ReciprocalRank parsedItem = ReciprocalRank.fromXContent(itemParser, () -> ParseFieldMatcher.STRICT);
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
}

View File

@ -0,0 +1,45 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
public class XContentTestHelper {
public static XContentParser roundtrip(ToXContentToBytes testItem) throws IOException {
XContentBuilder builder = XContentFactory.contentBuilder(ESTestCase.randomFrom(XContentType.values()));
if (ESTestCase.randomBoolean()) {
builder.prettyPrint();
}
testItem.toXContent(builder, ToXContent.EMPTY_PARAMS);
XContentBuilder shuffled = ESTestCase.shuffleXContent(builder);
XContentParser itemParser = XContentHelper.createParser(shuffled.bytes());
return itemParser;
}
}