Serialisation and validation checks for rank evaluation request components (#21975)

Adds tests around serialisation/validation checks for rank evaluation request components

* Add null/empty string checks to RatedDocument constructor
* Add mutation test to RatedDocument serialization tests.
* Reorganise rank-eval RatedDocument tests and add serialisation test.
* Add roundtrip serialisation testing for RatedRequests
* Adds serialisation testing and equals/hashcode testing for RatedRequest.
* Fixes a bug in previous equals implementation of RatedRequest along the way.
* Add roundtrip tests for Precision and ReciprocalRank
* Also fixes a bug with serialising ReciprocalRank.
* Add roundtrip testing for DiscountedCumulativeGain
* Add serialisation test for DocumentKey and fix test init
* Add check that relevant doc threshold is never negative for precision.
* Check that relevant threshold is never negative for precision and reciprocal rank

Closes #21401
This commit is contained in:
Isabel Drost-Fromm 2016-12-07 11:47:47 +01:00 committed by GitHub
parent 5c6cdb90ad
commit 165cec2757
10 changed files with 267 additions and 8 deletions

View File

@ -20,6 +20,7 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.support.ToXContentToBytes;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -47,6 +48,16 @@ public class DocumentKey extends ToXContentToBytes implements Writeable {
}
public DocumentKey(String index, String type, String docId) {
if (Strings.isNullOrEmpty(index)) {
throw new IllegalArgumentException("Index must be set for each rated document");
}
if(Strings.isNullOrEmpty(type)) {
throw new IllegalArgumentException("Type must be set for each rated document");
}
if (Strings.isNullOrEmpty(docId)) {
throw new IllegalArgumentException("DocId must be set for each rated document");
}
this.index = index;
this.type = type;
this.docId = docId;

View File

@ -90,6 +90,9 @@ public class Precision implements RankedListQualityMetric {
* Sets the rating threshold above which ratings are considered to be "relevant" for this metric.
* */
public void setRelevantRatingThreshhold(int threshold) {
    // Only negative values are rejected; zero is accepted as a threshold.
    final boolean negative = threshold < 0;
    if (negative) {
        throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
    }
    relevantRatingThreshhold = threshold;
}

View File

@ -289,12 +289,14 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
RatedRequest other = (RatedRequest) obj;
return Objects.equals(specId, other.specId) &&
Objects.equals(testRequest, other.testRequest) &&
Objects.equals(indices, other.indices) &&
Objects.equals(types, other.types) &&
Objects.equals(summaryFields, summaryFields) &&
Objects.equals(summaryFields, other.summaryFields) &&
Objects.equals(ratedDocs, other.ratedDocs) &&
Objects.equals(params, other.params);
}

View File

@ -57,7 +57,7 @@ public class ReciprocalRank implements RankedListQualityMetric {
}
public ReciprocalRank(StreamInput in) throws IOException {
this.relevantRatingThreshhold = in.readInt();
this.relevantRatingThreshhold = in.readVInt();
}
@Override
@ -69,6 +69,10 @@ public class ReciprocalRank implements RankedListQualityMetric {
* Sets the rating threshold above which ratings are considered to be "relevant" for this metric.
* */
public void setRelevantRatingThreshhold(int threshold) {
    // Only negative values are rejected; zero is accepted as a threshold.
    if (threshold < 0) {
        // The message names this metric (reciprocal rank) so that the failure points at the right setting.
        throw new IllegalArgumentException("Relevant rating threshold for reciprocal rank must be positive integer.");
    }
    this.relevantRatingThreshhold = threshold;
}

View File

@ -214,4 +214,32 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
/**
 * Copies a random metric with {@link RankEvalTestHelper#copy} and checks that the copy is a
 * distinct instance that is equal to, and hashes like, the source.
 */
public void testSerialization() throws IOException {
    final DiscountedCumulativeGain source = createTestItem();
    final DiscountedCumulativeGain copied = RankEvalTestHelper.copy(source, DiscountedCumulativeGain::new);
    assertNotSame(copied, source);
    assertEquals(copied, source);
    assertEquals(copied.hashCode(), source.hashCode());
}
/**
 * Hands a random metric, a mutated variant and a copy of it to
 * {@link RankEvalTestHelper#testHashCodeAndEquals}.
 */
public void testEqualsAndHash() throws IOException {
    final DiscountedCumulativeGain item = createTestItem();
    final DiscountedCumulativeGain mutation = mutateTestItem(item);
    final DiscountedCumulativeGain duplicate = RankEvalTestHelper.copy(item, DiscountedCumulativeGain::new);
    RankEvalTestHelper.testHashCodeAndEquals(item, mutation, duplicate);
}
/**
 * Builds a metric that starts with the same settings as {@code original} and then has exactly
 * one randomly chosen setting (normalize flag or unknown-doc rating) changed.
 */
private DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
    final boolean originalNormalize = original.getNormalize();
    final int originalUnknownDocRating = original.getUnknownDocRating();

    final DiscountedCumulativeGain mutated = new DiscountedCumulativeGain();
    mutated.setNormalize(originalNormalize);
    mutated.setUnknownDocRating(originalUnknownDocRating);

    // Candidates are deferred so that only the selected one draws random values.
    final List<Runnable> candidates = new ArrayList<>();
    candidates.add(() -> mutated.setNormalize(originalNormalize == false));
    candidates.add(() -> {
        int otherRating = randomValueOtherThan(originalUnknownDocRating, () -> randomIntBetween(0, 10));
        mutated.setUnknownDocRating(otherRating);
    });
    randomFrom(candidates).run();
    return mutated;
}
}

View File

@ -27,13 +27,13 @@ import java.io.IOException;
public class DocumentKeyTests extends ESTestCase {
static DocumentKey createRandomRatedDocumentKey() {
String index = randomAsciiOfLengthBetween(0, 10);
String type = randomAsciiOfLengthBetween(0, 10);
String docId = randomAsciiOfLengthBetween(0, 10);
String index = randomAsciiOfLengthBetween(1, 10);
String type = randomAsciiOfLengthBetween(1, 10);
String docId = randomAsciiOfLengthBetween(1, 10);
return new DocumentKey(index, type, docId);
}
public DocumentKey createRandomTestItem() {
public DocumentKey createTestItem() {
return createRandomRatedDocumentKey();
}
@ -62,4 +62,14 @@ public class DocumentKeyTests extends ESTestCase {
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
new DocumentKey(testItem.getIndex(), testItem.getType(), testItem.getDocID()));
}
/**
 * Copies a random document key with {@link RankEvalTestHelper#copy} and checks that the copy
 * is a distinct instance that is equal to, and hashes like, the source.
 */
public void testSerialization() throws IOException {
    final DocumentKey source = createTestItem();
    final DocumentKey copied = RankEvalTestHelper.copy(source, DocumentKey::new);
    assertNotSame(copied, source);
    assertEquals(copied, source);
    assertEquals(copied.hashCode(), source.hashCode());
}
}

View File

@ -167,6 +167,11 @@ public class PrecisionTests extends ESTestCase {
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
}
/**
 * A negative relevant-rating threshold must be rejected with an {@link IllegalArgumentException}.
 */
public void testInvalidRelevantThreshold() {
    final Precision precision = new Precision();
    final int negativeThreshold = -1;
    expectThrows(IllegalArgumentException.class, () -> precision.setRelevantRatingThreshhold(negativeThreshold));
}
public static Precision createTestItem() {
Precision precision = new Precision();
if (randomBoolean()) {
@ -186,6 +191,34 @@ public class PrecisionTests extends ESTestCase {
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
/**
 * Copies a random precision metric with {@link RankEvalTestHelper#copy} and checks that the
 * copy is a distinct instance that is equal to, and hashes like, the source.
 */
public void testSerialization() throws IOException {
    final Precision source = createTestItem();
    final Precision copied = RankEvalTestHelper.copy(source, Precision::new);
    assertNotSame(copied, source);
    assertEquals(copied, source);
    assertEquals(copied.hashCode(), source.hashCode());
}
/**
 * Hands a random precision metric, a mutated variant and a copy of it to
 * {@link RankEvalTestHelper#testHashCodeAndEquals}.
 */
public void testEqualsAndHash() throws IOException {
    final Precision item = createTestItem();
    final Precision mutation = mutateTestItem(item);
    final Precision duplicate = RankEvalTestHelper.copy(item, Precision::new);
    RankEvalTestHelper.testHashCodeAndEquals(item, mutation, duplicate);
}
/**
 * Builds a precision metric that starts with the same settings as {@code original} and then
 * has exactly one randomly chosen setting (ignore-unlabeled flag or relevant threshold) changed.
 */
private Precision mutateTestItem(Precision original) {
    final boolean originalIgnoreUnlabeled = original.getIgnoreUnlabeled();
    final int originalThreshold = original.getRelevantRatingThreshold();

    final Precision mutated = new Precision();
    mutated.setIgnoreUnlabeled(originalIgnoreUnlabeled);
    mutated.setRelevantRatingThreshhold(originalThreshold);

    // Candidates are deferred so that only the selected one draws random values.
    final List<Runnable> candidates = new ArrayList<>();
    candidates.add(() -> mutated.setIgnoreUnlabeled(originalIgnoreUnlabeled == false));
    candidates.add(() -> {
        int otherThreshold = randomValueOtherThan(originalThreshold, () -> randomIntBetween(0, 10));
        mutated.setRelevantRatingThreshhold(otherThreshold);
    });
    randomFrom(candidates).run();
    return mutated;
}
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index, String type) {
InternalSearchHit[] hits = new InternalSearchHit[rated.size()];

View File

@ -44,4 +44,52 @@ public class RatedDocumentTests extends ESTestCase {
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
/**
 * Copies a random rated document with {@link RankEvalTestHelper#copy} and checks that the
 * copy is a distinct instance that is equal to, and hashes like, the source.
 */
public void testSerialization() throws IOException {
    final RatedDocument source = createRatedDocument();
    final RatedDocument copied = RankEvalTestHelper.copy(source, RatedDocument::new);
    assertNotSame(copied, source);
    assertEquals(copied, source);
    assertEquals(copied.hashCode(), source.hashCode());
}
/**
 * Hands a random rated document, a mutated variant and a copy of it to
 * {@link RankEvalTestHelper#testHashCodeAndEquals}.
 */
public void testEqualsAndHash() throws IOException {
    final RatedDocument item = createRatedDocument();
    final RatedDocument mutation = mutateTestItem(item);
    final RatedDocument duplicate = RankEvalTestHelper.copy(item, RatedDocument::new);
    RankEvalTestHelper.testHashCodeAndEquals(item, mutation, duplicate);
}
/**
 * Constructing a rated document with a null or empty index, type or doc id must fail with an
 * {@link IllegalArgumentException}.
 */
public void testInvalidParsing() throws IOException {
    // Each row is {index, type, docId}; exactly one component per row is null or empty.
    final String[][] invalidKeys = {
        { null, "abc", "abc" },
        { "", "abc", "abc" },
        { "abc", null, "abc" },
        { "abc", "", "abc" },
        { "abc", "abc", null },
        { "abc", "abc", "" },
    };
    for (String[] key : invalidKeys) {
        expectThrows(IllegalArgumentException.class, () -> new RatedDocument(key[0], key[1], key[2], 10));
    }
}
/**
 * Returns a new rated document that differs from {@code original} in exactly one randomly
 * chosen component: rating, index, type or doc id.
 */
private static RatedDocument mutateTestItem(RatedDocument original) {
    int rating = original.getRating();
    String index = original.getIndex();
    String type = original.getType();
    String docId = original.getDocID();

    switch (randomIntBetween(0, 3)) {
    case 0:
        rating = randomValueOtherThan(rating, () -> randomInt());
        break;
    case 1:
        index = randomValueOtherThan(index, () -> randomAsciiOfLength(10));
        break;
    case 2:
        type = randomValueOtherThan(type, () -> randomAsciiOfLength(10));
        break;
    case 3:
        docId = randomValueOtherThan(docId, () -> randomAsciiOfLength(10));
        break;
    default:
        // Guards against the random range and the case labels drifting apart.
        throw new IllegalStateException("The test should only allow four parameters mutated");
    }
    return new RatedDocument(index, type, docId, rating);
}
}

View File

@ -20,11 +20,13 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ParseFieldRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
import org.elasticsearch.search.SearchModule;
@ -131,6 +133,92 @@ public class RatedRequestsTests extends ESTestCase {
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
}
/**
 * Copies a random rated request with {@link RankEvalTestHelper#copy}, using a registry that
 * knows the match-all query, and checks that the copy is a distinct instance that is equal to,
 * and hashes like, the source.
 */
public void testSerialization() throws IOException {
    // Random draws keep the original order: index count, index names, type count, type names.
    final int indexCount = randomIntBetween(0, 20);
    final List<String> indexNames = new ArrayList<>();
    while (indexNames.size() < indexCount) {
        indexNames.add(randomAsciiOfLengthBetween(0, 50));
    }

    final int typeCount = randomIntBetween(0, 20);
    final List<String> typeNames = new ArrayList<>();
    while (typeNames.size() < typeCount) {
        typeNames.add(randomAsciiOfLengthBetween(0, 50));
    }

    final RatedRequest source = createTestItem(indexNames, typeNames);

    final List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
    entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
    final NamedWriteableRegistry registry = new NamedWriteableRegistry(entries);

    final RatedRequest copied = RankEvalTestHelper.copy(source, RatedRequest::new, registry);
    assertNotSame(copied, source);
    assertEquals(copied, source);
    assertEquals(copied.hashCode(), source.hashCode());
}
/**
 * Hands a random rated request, a mutated variant and a copy of it to
 * {@link RankEvalTestHelper#testHashCodeAndEquals}.
 */
public void testEqualsAndHash() throws IOException {
    // Random draws keep the original order: index count, index names, type count, type names.
    final int indexCount = randomIntBetween(0, 20);
    final List<String> indexNames = new ArrayList<>();
    while (indexNames.size() < indexCount) {
        indexNames.add(randomAsciiOfLengthBetween(0, 50));
    }

    final int typeCount = randomIntBetween(0, 20);
    final List<String> typeNames = new ArrayList<>();
    while (typeNames.size() < typeCount) {
        typeNames.add(randomAsciiOfLengthBetween(0, 50));
    }

    final RatedRequest item = createTestItem(indexNames, typeNames);

    final List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
    entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
    final NamedWriteableRegistry registry = new NamedWriteableRegistry(entries);

    final RatedRequest mutation = mutateTestItem(item);
    final RatedRequest duplicate = RankEvalTestHelper.copy(item, RatedRequest::new, registry);
    RankEvalTestHelper.testHashCodeAndEquals(item, mutation, duplicate);
}
/**
 * Rebuilds a request from the components of {@code original} and then changes exactly one
 * randomly chosen component: spec id, search size, rated docs, indices, params, summary
 * fields or types.
 *
 * The rebuilt search source always carries a match-all query and the size of the original
 * request; nothing else is copied from the original search source.
 */
private RatedRequest mutateTestItem(RatedRequest original) {
    String specId = original.getSpecId();
    int size = original.getTestRequest().size();
    List<RatedDocument> ratedDocs = original.getRatedDocs();
    List<String> indices = original.getIndices();
    List<String> types = original.getTypes();
    Map<String, Object> params = original.getParams();
    List<String> summaryFields = original.getSummaryFields();

    SearchSourceBuilder testRequest = new SearchSourceBuilder();
    testRequest.size(size);
    testRequest.query(new MatchAllQueryBuilder());

    RatedRequest ratedRequest = new RatedRequest(specId, testRequest, indices, types, ratedDocs);
    ratedRequest.setIndices(indices);
    ratedRequest.setTypes(types);
    ratedRequest.setParams(params);
    ratedRequest.setSummaryFields(summaryFields);

    // Mutations are deferred so that only the selected one draws random values.
    List<Runnable> mutators = new ArrayList<>();
    mutators.add(() -> ratedRequest.setSpecId(randomValueOtherThan(specId, () -> randomAsciiOfLength(10))));
    mutators.add(() -> ratedRequest.getTestRequest().size(randomValueOtherThan(size, () -> randomInt())));
    mutators.add(() -> ratedRequest.setRatedDocs(
            Arrays.asList(randomValueOtherThanMany(ratedDocs::contains, () -> RatedDocumentTests.createRatedDocument()))));
    mutators.add(() -> ratedRequest.setIndices(
            Arrays.asList(randomValueOtherThanMany(indices::contains, () -> randomAsciiOfLength(10)))));

    // Params are mutated by adding one extra entry to a copy of the original map.
    Map<String, Object> modified = new HashMap<>(params);
    modified.put("one_more_key", "one_more_value");
    mutators.add(() -> ratedRequest.setParams(modified));

    mutators.add(() -> ratedRequest.setSummaryFields(
            Arrays.asList(randomValueOtherThanMany(summaryFields::contains, () -> randomAsciiOfLength(10)))));
    // Types take part in RatedRequest#equals as well, so they get their own mutation.
    mutators.add(() -> ratedRequest.setTypes(
            Arrays.asList(randomValueOtherThanMany(types::contains, () -> randomAsciiOfLength(10)))));

    randomFrom(mutators).run();
    return ratedRequest;
}
public void testDuplicateRatedDocThrowsException() {
RatedRequest request = createTestItem(Arrays.asList("index"), Arrays.asList("type"));

View File

@ -123,8 +123,7 @@ public class ReciprocalRankTests extends ESTestCase {
}
public void testXContentRoundtrip() throws IOException {
ReciprocalRank testItem = new ReciprocalRank();
testItem.setRelevantRatingThreshhold(randomIntBetween(0, 20));
ReciprocalRank testItem = createTestItem();
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
itemParser.nextToken();
itemParser.nextToken();
@ -146,4 +145,37 @@ public class ReciprocalRankTests extends ESTestCase {
}
return hits;
}
/**
 * Creates a reciprocal rank metric with a random relevant-rating threshold between 0 and 20.
 */
private ReciprocalRank createTestItem() {
    final ReciprocalRank metric = new ReciprocalRank();
    final int threshold = randomIntBetween(0, 20);
    metric.setRelevantRatingThreshhold(threshold);
    return metric;
}
/**
 * Copies a random reciprocal rank metric with {@link RankEvalTestHelper#copy} and checks that
 * the copy is a distinct instance that is equal to, and hashes like, the source.
 */
public void testSerialization() throws IOException {
    final ReciprocalRank source = createTestItem();
    final ReciprocalRank copied = RankEvalTestHelper.copy(source, ReciprocalRank::new);
    assertNotSame(copied, source);
    assertEquals(copied, source);
    assertEquals(copied.hashCode(), source.hashCode());
}
/**
 * Hands a random reciprocal rank metric, a mutated variant and a copy of it to
 * {@link RankEvalTestHelper#testHashCodeAndEquals}.
 */
public void testEqualsAndHash() throws IOException {
    final ReciprocalRank item = createTestItem();
    final ReciprocalRank mutation = mutateTestItem(item);
    final ReciprocalRank duplicate = RankEvalTestHelper.copy(item, ReciprocalRank::new);
    RankEvalTestHelper.testHashCodeAndEquals(item, mutation, duplicate);
}
/**
 * Returns a new metric whose relevant-rating threshold is a random value between 0 and 10
 * that differs from the threshold of {@code testItem}.
 */
private ReciprocalRank mutateTestItem(ReciprocalRank testItem) {
    final int originalThreshold = testItem.getRelevantRatingThreshold();
    final ReciprocalRank mutated = new ReciprocalRank();
    final int otherThreshold = randomValueOtherThan(originalThreshold, () -> randomIntBetween(0, 10));
    mutated.setRelevantRatingThreshhold(otherThreshold);
    return mutated;
}
/**
 * A negative relevant-rating threshold must be rejected with an {@link IllegalArgumentException}.
 */
public void testInvalidRelevantThreshold() {
    final ReciprocalRank reciprocalRank = new ReciprocalRank();
    final int negativeThreshold = -1;
    expectThrows(IllegalArgumentException.class, () -> reciprocalRank.setRelevantRatingThreshhold(negativeThreshold));
}
}