diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java index acdbcd7ee60..3386caa38ee 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.index.rankeval; import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.ParseFieldRegistry; import org.elasticsearch.common.xcontent.ToXContent; @@ -27,6 +28,8 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.index.rankeval.RankEvalSpec.ScriptWithId; import org.elasticsearch.indices.query.IndicesQueriesRegistry; @@ -50,6 +53,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.function.Supplier; import static java.util.Collections.emptyList; @@ -131,7 +135,7 @@ public class RankEvalSpecTests extends ESTestCase { return spec; } - public void testRoundtripping() throws IOException { + public void testXContentRoundtrip() throws IOException { RankEvalSpec testItem = createTestItem(); XContentBuilder shuffled = ESTestCase.shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); @@ -149,6 +153,83 @@ public class RankEvalSpecTests extends ESTestCase { assertEquals(testItem.hashCode(), parsedItem.hashCode()); } + public void testSerialization() throws IOException { + RankEvalSpec original = createTestItem(); + + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, Precision.NAME, Precision::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry( + RankedListQualityMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new)); + + + RankEvalSpec deserialized = RankEvalTestHelper.copy(original, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)); + assertEquals(deserialized, original); + assertEquals(deserialized.hashCode(), original.hashCode()); + assertNotSame(deserialized, original); + } + + public void testEqualsAndHash() throws IOException { + RankEvalSpec testItem = createTestItem(); + + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, Precision.NAME, Precision::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry( + RankedListQualityMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RankedListQualityMetric.class, ReciprocalRank.NAME, ReciprocalRank::new)); + + RankEvalSpec mutant = RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables)); + RankEvalTestHelper.testHashCodeAndEquals(testItem, mutateTestItem(mutant), + RankEvalTestHelper.copy(testItem, RankEvalSpec::new, new NamedWriteableRegistry(namedWriteables))); + } + + private RankEvalSpec mutateTestItem(RankEvalSpec mutant) { + Collection ratedRequests = mutant.getRatedRequests(); + RankedListQualityMetric metric = mutant.getMetric(); + Map templates = mutant.getTemplates(); + + int mutate = randomIntBetween(0, 2); + switch (mutate) { + case 0: + RatedRequest request = RatedRequestsTests.createTestItem(new ArrayList<>(), new ArrayList<>()); + ratedRequests.add(request); + break; + case 1: + if (metric instanceof Precision) { + metric = new DiscountedCumulativeGain(); + } else { + metric = new Precision(); + } + break; + case 2: + if (templates.size() > 0) { + if (randomBoolean()) { + templates = null; + } else { + String mutatedTemplate = randomAsciiOfLength(10); + templates.put("mutation", new Script(ScriptType.INLINE, "mustache", mutatedTemplate, new HashMap<>())); + + } + } else { + String mutatedTemplate = randomValueOtherThanMany(templates::containsValue, () -> randomAsciiOfLength(10)); + templates.put("mutation", new Script(ScriptType.INLINE, "mustache", mutatedTemplate, new HashMap<>())); + } + break; + default: + throw new IllegalStateException("Requested to modify more than available parameters."); + } + + List scripts = new ArrayList<>(); + for (Entry entry : templates.entrySet()) { + scripts.add(new ScriptWithId(entry.getKey(), entry.getValue())); + } + + RankEvalSpec result = new RankEvalSpec(ratedRequests, metric, scripts); + return result; + } + public void testMissingRatedRequestsFailsParsing() { RankedListQualityMetric metric = new Precision(); expectThrows(IllegalStateException.class, () -> new RankEvalSpec(new ArrayList<>(), metric)); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java index 67b7a65e338..4844c21c299 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java @@ -202,16 +202,27 @@ public class RatedRequestsTests extends ESTestCase { List summaryFields = original.getSummaryFields(); String templateId = original.getTemplateId(); - int mutate = randomIntBetween(0, 7); + int mutate = randomIntBetween(0, 5); switch (mutate) { case 0: id = randomValueOtherThan(id, () -> randomAsciiOfLength(10)); break; case 1: - int size = randomValueOtherThan(testRequest.size(), () -> randomInt()); - testRequest = new SearchSourceBuilder(); - testRequest.size(size); - testRequest.query(new MatchAllQueryBuilder()); + if (testRequest != null) { + int size = randomValueOtherThan(testRequest.size(), () -> randomInt()); + testRequest = new SearchSourceBuilder(); + testRequest.size(size); + testRequest.query(new MatchAllQueryBuilder()); + } else { + if (randomBoolean()) { + Map mutated = new HashMap<>(); + mutated.putAll(params); + mutated.put("one_more_key", "one_more_value"); + params = mutated; + } else { + templateId = randomValueOtherThan(templateId, () -> randomAsciiOfLength(5)); + } + } break; case 2: ratedDocs = Arrays.asList( @@ -224,16 +235,8 @@ public class RatedRequestsTests extends ESTestCase { types = Arrays.asList(randomValueOtherThanMany(types::contains, () -> randomAsciiOfLength(10))); break; case 5: - params = new HashMap<>(); - params.putAll(params); - params.put("one_more_key", "one_more_value"); - break; - case 6: summaryFields = Arrays.asList(randomValueOtherThanMany(summaryFields::contains, () -> randomAsciiOfLength(10))); break; - case 7: - templateId = randomValueOtherThan(templateId, () -> randomAsciiOfLength(5)); - break; default: throw new IllegalStateException("Requested to modify more than available parameters."); }