Serialisation and validation checks for rank evaluation request components (#21975)
Adds tests around serialisation/validation checks for rank evaluation request components:
* Add null/empty string checks to the RatedDocument constructor.
* Add mutation test to RatedDocument serialisation tests.
* Reorganise rank-eval RatedDocument tests and add serialisation test.
* Add roundtrip serialisation testing for RatedRequests.
* Add serialisation testing and equals/hashCode testing for RatedRequest.
* Fix a bug in the previous equals implementation of RatedRequest along the way.
* Add roundtrip tests for Precision and ReciprocalRank.
* Also fix a bug with serialising ReciprocalRank.
* Add roundtrip testing for DiscountedCumulativeGain.
* Add serialisation test for DocumentKey and fix test init.
* Add check that relevant doc threshold is always positive for precision.
* Check that relevant threshold is always positive for precision and reciprocal rank.
Closes #21401
This commit is contained in:
parent
5c6cdb90ad
commit
165cec2757
|
@ -20,6 +20,7 @@
|
|||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.action.support.ToXContentToBytes;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
|
@ -47,6 +48,16 @@ public class DocumentKey extends ToXContentToBytes implements Writeable {
|
|||
}
|
||||
|
||||
public DocumentKey(String index, String type, String docId) {
|
||||
if (Strings.isNullOrEmpty(index)) {
|
||||
throw new IllegalArgumentException("Index must be set for each rated document");
|
||||
}
|
||||
if(Strings.isNullOrEmpty(type)) {
|
||||
throw new IllegalArgumentException("Type must be set for each rated document");
|
||||
}
|
||||
if (Strings.isNullOrEmpty(docId)) {
|
||||
throw new IllegalArgumentException("DocId must be set for each rated document");
|
||||
}
|
||||
|
||||
this.index = index;
|
||||
this.type = type;
|
||||
this.docId = docId;
|
||||
|
|
|
@ -90,6 +90,9 @@ public class Precision implements RankedListQualityMetric {
|
|||
* Sets the rating threshold above which ratings are considered to be "relevant" for this metric.
|
||||
* */
|
||||
public void setRelevantRatingThreshhold(int threshold) {
|
||||
if (threshold < 0) {
|
||||
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
}
|
||||
|
||||
|
|
|
@ -289,12 +289,14 @@ public class RatedRequest extends ToXContentToBytes implements Writeable {
|
|||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
RatedRequest other = (RatedRequest) obj;
|
||||
|
||||
return Objects.equals(specId, other.specId) &&
|
||||
Objects.equals(testRequest, other.testRequest) &&
|
||||
Objects.equals(indices, other.indices) &&
|
||||
Objects.equals(types, other.types) &&
|
||||
Objects.equals(summaryFields, summaryFields) &&
|
||||
Objects.equals(summaryFields, other.summaryFields) &&
|
||||
Objects.equals(ratedDocs, other.ratedDocs) &&
|
||||
Objects.equals(params, other.params);
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ public class ReciprocalRank implements RankedListQualityMetric {
|
|||
}
|
||||
|
||||
/**
 * Reads a metric from the given stream.
 *
 * @param in stream to read the relevant rating threshold from
 * @throws IOException if reading from the stream fails
 */
public ReciprocalRank(StreamInput in) throws IOException {
    // Exactly one read, as a variable-length int. The superseded fixed-width
    // readInt() call is dropped: keeping both would consume the stream twice.
    this.relevantRatingThreshhold = in.readVInt();
}
|
||||
|
||||
@Override
|
||||
|
@ -69,6 +69,10 @@ public class ReciprocalRank implements RankedListQualityMetric {
|
|||
* Sets the rating threshold above which ratings are considered to be "relevant" for this metric.
|
||||
* */
|
||||
public void setRelevantRatingThreshhold(int threshold) {
|
||||
if (threshold < 0) {
|
||||
throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer.");
|
||||
}
|
||||
|
||||
this.relevantRatingThreshhold = threshold;
|
||||
}
|
||||
|
||||
|
|
|
@ -214,4 +214,32 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
|
|||
assertEquals(testItem, parsedItem);
|
||||
assertEquals(testItem.hashCode(), parsedItem.hashCode());
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
DiscountedCumulativeGain original = createTestItem();
|
||||
DiscountedCumulativeGain deserialized = RankEvalTestHelper.copy(original, DiscountedCumulativeGain::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
DiscountedCumulativeGain testItem = createTestItem();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, DiscountedCumulativeGain::new));
|
||||
}
|
||||
|
||||
private DiscountedCumulativeGain mutateTestItem(DiscountedCumulativeGain original) {
|
||||
boolean normalise = original.getNormalize();
|
||||
int unknownDocRating = original.getUnknownDocRating();
|
||||
DiscountedCumulativeGain gain = new DiscountedCumulativeGain();
|
||||
gain.setNormalize(normalise);
|
||||
gain.setUnknownDocRating(unknownDocRating);
|
||||
|
||||
List<Runnable> mutators = new ArrayList<>();
|
||||
mutators.add(() -> gain.setNormalize(! original.getNormalize()));
|
||||
mutators.add(() -> gain.setUnknownDocRating(randomValueOtherThan(unknownDocRating, () -> randomIntBetween(0, 10))));
|
||||
randomFrom(mutators).run();
|
||||
return gain;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,13 +27,13 @@ import java.io.IOException;
|
|||
public class DocumentKeyTests extends ESTestCase {
|
||||
|
||||
static DocumentKey createRandomRatedDocumentKey() {
|
||||
String index = randomAsciiOfLengthBetween(0, 10);
|
||||
String type = randomAsciiOfLengthBetween(0, 10);
|
||||
String docId = randomAsciiOfLengthBetween(0, 10);
|
||||
String index = randomAsciiOfLengthBetween(1, 10);
|
||||
String type = randomAsciiOfLengthBetween(1, 10);
|
||||
String docId = randomAsciiOfLengthBetween(1, 10);
|
||||
return new DocumentKey(index, type, docId);
|
||||
}
|
||||
|
||||
public DocumentKey createRandomTestItem() {
|
||||
public DocumentKey createTestItem() {
|
||||
return createRandomRatedDocumentKey();
|
||||
}
|
||||
|
||||
|
@ -62,4 +62,14 @@ public class DocumentKeyTests extends ESTestCase {
|
|||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
new DocumentKey(testItem.getIndex(), testItem.getType(), testItem.getDocID()));
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
DocumentKey original = createTestItem();
|
||||
DocumentKey deserialized = RankEvalTestHelper.copy(original, DocumentKey::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -167,6 +167,11 @@ public class PrecisionTests extends ESTestCase {
|
|||
assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
Precision prez = new Precision();
|
||||
expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1));
|
||||
}
|
||||
|
||||
public static Precision createTestItem() {
|
||||
Precision precision = new Precision();
|
||||
if (randomBoolean()) {
|
||||
|
@ -186,6 +191,34 @@ public class PrecisionTests extends ESTestCase {
|
|||
assertEquals(testItem, parsedItem);
|
||||
assertEquals(testItem.hashCode(), parsedItem.hashCode());
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
Precision original = createTestItem();
|
||||
Precision deserialized = RankEvalTestHelper.copy(original, Precision::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
Precision testItem = createTestItem();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, Precision::new));
|
||||
}
|
||||
|
||||
private Precision mutateTestItem(Precision original) {
|
||||
boolean ignoreUnlabeled = original.getIgnoreUnlabeled();
|
||||
int relevantThreshold = original.getRelevantRatingThreshold();
|
||||
Precision precision = new Precision();
|
||||
precision.setIgnoreUnlabeled(ignoreUnlabeled);
|
||||
precision.setRelevantRatingThreshhold(relevantThreshold);
|
||||
|
||||
List<Runnable> mutators = new ArrayList<>();
|
||||
mutators.add(() -> precision.setIgnoreUnlabeled(! ignoreUnlabeled));
|
||||
mutators.add(() -> precision.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10))));
|
||||
randomFrom(mutators).run();
|
||||
return precision;
|
||||
}
|
||||
|
||||
private static SearchHit[] toSearchHits(List<RatedDocument> rated, String index, String type) {
|
||||
InternalSearchHit[] hits = new InternalSearchHit[rated.size()];
|
||||
|
|
|
@ -44,4 +44,52 @@ public class RatedDocumentTests extends ESTestCase {
|
|||
assertEquals(testItem, parsedItem);
|
||||
assertEquals(testItem.hashCode(), parsedItem.hashCode());
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
RatedDocument original = createRatedDocument();
|
||||
RatedDocument deserialized = RankEvalTestHelper.copy(original, RatedDocument::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
RatedDocument testItem = createRatedDocument();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, RatedDocument::new));
|
||||
}
|
||||
|
||||
public void testInvalidParsing() throws IOException {
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument(null, "abc", "abc", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("", "abc", "abc", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", null, "abc", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", "", "abc", 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", "abc", null, 10));
|
||||
expectThrows(IllegalArgumentException.class, () -> new RatedDocument("abc", "abc", "", 10));
|
||||
}
|
||||
|
||||
private static RatedDocument mutateTestItem(RatedDocument original) {
|
||||
int rating = original.getRating();
|
||||
String index = original.getIndex();
|
||||
String type = original.getType();
|
||||
String docId = original.getDocID();
|
||||
|
||||
switch (randomIntBetween(0, 3)) {
|
||||
case 0:
|
||||
rating = randomValueOtherThan(rating, () -> randomInt());
|
||||
break;
|
||||
case 1:
|
||||
index = randomValueOtherThan(index, () -> randomAsciiOfLength(10));
|
||||
break;
|
||||
case 2:
|
||||
type = randomValueOtherThan(type, () -> randomAsciiOfLength(10));
|
||||
break;
|
||||
case 3:
|
||||
docId = randomValueOtherThan(docId, () -> randomAsciiOfLength(10));
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException("The test should only allow two parameters mutated");
|
||||
}
|
||||
return new RatedDocument(index, type, docId, rating);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,11 +20,13 @@
|
|||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.common.ParseFieldMatcher;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.ParseFieldRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryParseContext;
|
||||
import org.elasticsearch.indices.query.IndicesQueriesRegistry;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
|
@ -131,6 +133,92 @@ public class RatedRequestsTests extends ESTestCase {
|
|||
assertEquals(testItem, parsedItem);
|
||||
assertEquals(testItem.hashCode(), parsedItem.hashCode());
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
List<String> indices = new ArrayList<>();
|
||||
int size = randomIntBetween(0, 20);
|
||||
for (int i = 0; i < size; i++) {
|
||||
indices.add(randomAsciiOfLengthBetween(0, 50));
|
||||
}
|
||||
|
||||
List<String> types = new ArrayList<>();
|
||||
size = randomIntBetween(0, 20);
|
||||
for (int i = 0; i < size; i++) {
|
||||
types.add(randomAsciiOfLengthBetween(0, 50));
|
||||
}
|
||||
|
||||
RatedRequest original = createTestItem(indices, types);
|
||||
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
|
||||
RatedRequest deserialized = RankEvalTestHelper.copy(original, RatedRequest::new, new NamedWriteableRegistry(namedWriteables));
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
List<String> indices = new ArrayList<>();
|
||||
int size = randomIntBetween(0, 20);
|
||||
for (int i = 0; i < size; i++) {
|
||||
indices.add(randomAsciiOfLengthBetween(0, 50));
|
||||
}
|
||||
|
||||
List<String> types = new ArrayList<>();
|
||||
size = randomIntBetween(0, 20);
|
||||
for (int i = 0; i < size; i++) {
|
||||
types.add(randomAsciiOfLengthBetween(0, 50));
|
||||
}
|
||||
|
||||
RatedRequest testItem = createTestItem(indices, types);
|
||||
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new));
|
||||
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, RatedRequest::new, new NamedWriteableRegistry(namedWriteables)));
|
||||
}
|
||||
|
||||
private RatedRequest mutateTestItem(RatedRequest original) {
|
||||
String specId = original.getSpecId();
|
||||
int size = original.getTestRequest().size();
|
||||
List<RatedDocument> ratedDocs = original.getRatedDocs();
|
||||
List<String> indices = original.getIndices();
|
||||
List<String> types = original.getTypes();
|
||||
Map<String, Object> params = original.getParams();
|
||||
List<String> summaryFields = original.getSummaryFields();
|
||||
|
||||
SearchSourceBuilder testRequest = new SearchSourceBuilder();
|
||||
testRequest.size(size);
|
||||
testRequest.query(new MatchAllQueryBuilder());
|
||||
|
||||
RatedRequest ratedRequest = new RatedRequest(specId, testRequest, indices, types, ratedDocs);
|
||||
ratedRequest.setIndices(indices);
|
||||
ratedRequest.setTypes(types);
|
||||
ratedRequest.setParams(params);
|
||||
ratedRequest.setSummaryFields(summaryFields);
|
||||
|
||||
List<Runnable> mutators = new ArrayList<>();
|
||||
mutators.add(() -> ratedRequest.setSpecId(randomValueOtherThan(specId, () -> randomAsciiOfLength(10))));
|
||||
mutators.add(() -> ratedRequest.getTestRequest().size(randomValueOtherThan(size, () -> randomInt())));
|
||||
mutators.add(() -> ratedRequest.setRatedDocs(
|
||||
Arrays.asList(randomValueOtherThanMany(ratedDocs::contains, () -> RatedDocumentTests.createRatedDocument()))));
|
||||
mutators.add(() -> ratedRequest.setIndices(
|
||||
Arrays.asList(randomValueOtherThanMany(indices::contains, () -> randomAsciiOfLength(10)))));
|
||||
|
||||
HashMap<String, Object> modified = new HashMap<>();
|
||||
modified.putAll(params);
|
||||
modified.put("one_more_key", "one_more_value");
|
||||
mutators.add(() -> ratedRequest.setParams(modified));
|
||||
|
||||
mutators.add(() -> ratedRequest.setSummaryFields(
|
||||
Arrays.asList(randomValueOtherThanMany(summaryFields::contains, () -> randomAsciiOfLength(10)))));
|
||||
|
||||
|
||||
randomFrom(mutators).run();
|
||||
return ratedRequest;
|
||||
}
|
||||
|
||||
public void testDuplicateRatedDocThrowsException() {
|
||||
RatedRequest request = createTestItem(Arrays.asList("index"), Arrays.asList("type"));
|
||||
|
|
|
@ -123,8 +123,7 @@ public class ReciprocalRankTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testXContentRoundtrip() throws IOException {
|
||||
ReciprocalRank testItem = new ReciprocalRank();
|
||||
testItem.setRelevantRatingThreshhold(randomIntBetween(0, 20));
|
||||
ReciprocalRank testItem = createTestItem();
|
||||
XContentParser itemParser = RankEvalTestHelper.roundtrip(testItem);
|
||||
itemParser.nextToken();
|
||||
itemParser.nextToken();
|
||||
|
@ -146,4 +145,37 @@ public class ReciprocalRankTests extends ESTestCase {
|
|||
}
|
||||
return hits;
|
||||
}
|
||||
|
||||
private ReciprocalRank createTestItem() {
|
||||
ReciprocalRank testItem = new ReciprocalRank();
|
||||
testItem.setRelevantRatingThreshhold(randomIntBetween(0, 20));
|
||||
return testItem;
|
||||
}
|
||||
|
||||
public void testSerialization() throws IOException {
|
||||
ReciprocalRank original = createTestItem();
|
||||
|
||||
ReciprocalRank deserialized = RankEvalTestHelper.copy(original, ReciprocalRank::new);
|
||||
assertEquals(deserialized, original);
|
||||
assertEquals(deserialized.hashCode(), original.hashCode());
|
||||
assertNotSame(deserialized, original);
|
||||
}
|
||||
|
||||
public void testEqualsAndHash() throws IOException {
|
||||
ReciprocalRank testItem = createTestItem();
|
||||
RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(testItem),
|
||||
RankEvalTestHelper.copy(testItem, ReciprocalRank::new));
|
||||
}
|
||||
|
||||
private ReciprocalRank mutateTestItem(ReciprocalRank testItem) {
|
||||
int relevantThreshold = testItem.getRelevantRatingThreshold();
|
||||
ReciprocalRank rank = new ReciprocalRank();
|
||||
rank.setRelevantRatingThreshhold(randomValueOtherThan(relevantThreshold, () -> randomIntBetween(0, 10)));
|
||||
return rank;
|
||||
}
|
||||
|
||||
public void testInvalidRelevantThreshold() {
|
||||
ReciprocalRank prez = new ReciprocalRank();
|
||||
expectThrows(IllegalArgumentException.class, () -> prez.setRelevantRatingThreshhold(-1));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue