diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java index 8e0828fcfca..22875139c9b 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java @@ -57,7 +57,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject { /** Default max number of requests. */ private static final int MAX_CONCURRENT_SEARCHES = 10; /** optional: Templates to base test requests on */ - private Map templates = new HashMap<>(); + private final Map templates = new HashMap<>(); public RankEvalSpec(List ratedRequests, EvaluationMetric metric, Collection templates) { this.metric = Objects.requireNonNull(metric, "Cannot evaluate ranking if no evaluation metric is provided."); @@ -68,8 +68,8 @@ public class RankEvalSpec implements Writeable, ToXContentObject { this.ratedRequests = ratedRequests; if (templates == null || templates.isEmpty()) { for (RatedRequest request : ratedRequests) { - if (request.getTestRequest() == null) { - throw new IllegalStateException("Cannot evaluate ranking if neither template nor test request is " + if (request.getEvaluationRequest() == null) { + throw new IllegalStateException("Cannot evaluate ranking if neither template nor evaluation request is " + "provided. Seen for request id: " + request.getId()); } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java index 392ce5d0633..79dd693b3ac 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java @@ -75,9 +75,12 @@ public class RatedRequest implements Writeable, ToXContentObject { private final String id; private final List summaryFields; private final List ratedDocs; - // Search request to execute for this rated request. This can be null if template and corresponding parameters are supplied. + /** + * Search request to execute for this rated request. This can be null in + * case the query is supplied as a template with corresponding parameters + */ @Nullable - private SearchSourceBuilder testRequest; + private final SearchSourceBuilder evaluationRequest; /** * Map of parameters to use for filling a query template, can be used * instead of providing testRequest. @@ -86,27 +89,49 @@ public class RatedRequest implements Writeable, ToXContentObject { @Nullable private String templateId; - private RatedRequest(String id, List ratedDocs, SearchSourceBuilder testRequest, + /** + * Create a rated request with template ids and parameters. + * + * @param id a unique name for this rated request + * @param ratedDocs a list of document ratings + * @param params template parameters + * @param templateId a templare id + */ + public RatedRequest(String id, List ratedDocs, Map params, + String templateId) { + this(id, ratedDocs, null, params, templateId); + } + + /** + * Create a rated request using a {@link SearchSourceBuilder} to define the + * evaluated query. + * + * @param id a unique name for this rated request + * @param ratedDocs a list of document ratings + * @param evaluatedQuery the query that is evaluated + */ + public RatedRequest(String id, List ratedDocs, SearchSourceBuilder evaluatedQuery) { + this(id, ratedDocs, evaluatedQuery, new HashMap<>(), null); + } + + private RatedRequest(String id, List ratedDocs, SearchSourceBuilder evaluatedQuery, Map params, String templateId) { - if (params != null && (params.size() > 0 && testRequest != null)) { + if (params != null && (params.size() > 0 && evaluatedQuery != null)) { throw new IllegalArgumentException( - "Ambiguous rated request: Set both, verbatim test request and test request " - + "template parameters."); + "Ambiguous rated request: Set both, verbatim test request and test request " + "template parameters."); } - if (templateId != null && testRequest != null) { + if (templateId != null && evaluatedQuery != null) { throw new IllegalArgumentException( - "Ambiguous rated request: Set both, verbatim test request and test request " - + "template parameters."); + "Ambiguous rated request: Set both, verbatim test request and test request " + "template parameters."); } - if ((params == null || params.size() < 1) && testRequest == null) { - throw new IllegalArgumentException( - "Need to set at least test request or test request template parameters."); + if ((params == null || params.size() < 1) && evaluatedQuery == null) { + throw new IllegalArgumentException("Need to set at least test request or test request template parameters."); } if ((params != null && params.size() > 0) && templateId == null) { - throw new IllegalArgumentException( - "If template parameters are supplied need to set id of template to apply " - + "them to too."); + throw new IllegalArgumentException("If template parameters are supplied need to set id of template to apply " + "them to too."); } + validateEvaluatedQuery(evaluatedQuery); + // check that not two documents with same _index/id are specified Set docKeys = new HashSet<>(); for (RatedDocument doc : ratedDocs) { @@ -118,7 +143,7 @@ public class RatedRequest implements Writeable, ToXContentObject { } this.id = id; - this.testRequest = testRequest; + this.evaluationRequest = evaluatedQuery; this.ratedDocs = new ArrayList<>(ratedDocs); if (params != null) { this.params = new HashMap<>(params); @@ -129,18 +154,30 @@ public class RatedRequest implements Writeable, ToXContentObject { this.summaryFields = new ArrayList<>(); } - public RatedRequest(String id, List ratedDocs, Map params, - String templateId) { - this(id, ratedDocs, null, params, templateId); + static void validateEvaluatedQuery(SearchSourceBuilder evaluationRequest) { + // ensure that testRequest, if set, does not contain aggregation, suggest or highlighting section + if (evaluationRequest != null) { + if (evaluationRequest.suggest() != null) { + throw new IllegalArgumentException("Query in rated requests should not contain a suggest section."); + } + if (evaluationRequest.aggregations() != null) { + throw new IllegalArgumentException("Query in rated requests should not contain aggregations."); + } + if (evaluationRequest.highlighter() != null) { + throw new IllegalArgumentException("Query in rated requests should not contain a highlighter section."); + } + if (evaluationRequest.explain() != null && evaluationRequest.explain()) { + throw new IllegalArgumentException("Query in rated requests should not use explain."); + } + if (evaluationRequest.profile()) { + throw new IllegalArgumentException("Query in rated requests should not use profile."); + } + } } - public RatedRequest(String id, List ratedDocs, SearchSourceBuilder testRequest) { - this(id, ratedDocs, testRequest, new HashMap<>(), null); - } - - public RatedRequest(StreamInput in) throws IOException { + RatedRequest(StreamInput in) throws IOException { this.id = in.readString(); - testRequest = in.readOptionalWriteable(SearchSourceBuilder::new); + evaluationRequest = in.readOptionalWriteable(SearchSourceBuilder::new); int intentSize = in.readInt(); ratedDocs = new ArrayList<>(intentSize); @@ -159,7 +196,7 @@ public class RatedRequest implements Writeable, ToXContentObject { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); - out.writeOptionalWriteable(testRequest); + out.writeOptionalWriteable(evaluationRequest); out.writeInt(ratedDocs.size()); for (RatedDocument ratedDoc : ratedDocs) { @@ -173,8 +210,8 @@ public class RatedRequest implements Writeable, ToXContentObject { out.writeOptionalString(this.templateId); } - public SearchSourceBuilder getTestRequest() { - return testRequest; + public SearchSourceBuilder getEvaluationRequest() { + return evaluationRequest; } /** return the user supplied request id */ @@ -240,8 +277,8 @@ public class RatedRequest implements Writeable, ToXContentObject { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ID_FIELD.getPreferredName(), this.id); - if (testRequest != null) { - builder.field(REQUEST_FIELD.getPreferredName(), this.testRequest); + if (evaluationRequest != null) { + builder.field(REQUEST_FIELD.getPreferredName(), this.evaluationRequest); } builder.startArray(RATINGS_FIELD.getPreferredName()); for (RatedDocument doc : this.ratedDocs) { @@ -285,7 +322,7 @@ public class RatedRequest implements Writeable, ToXContentObject { RatedRequest other = (RatedRequest) obj; - return Objects.equals(id, other.id) && Objects.equals(testRequest, other.testRequest) + return Objects.equals(id, other.id) && Objects.equals(evaluationRequest, other.evaluationRequest) && Objects.equals(summaryFields, other.summaryFields) && Objects.equals(ratedDocs, other.ratedDocs) && Objects.equals(params, other.params) @@ -294,7 +331,7 @@ public class RatedRequest implements Writeable, ToXContentObject { @Override public final int hashCode() { - return Objects.hash(id, testRequest, summaryFields, ratedDocs, params, + return Objects.hash(id, evaluationRequest, summaryFields, ratedDocs, params, templateId); } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index 019ae274466..e0a0b3ea133 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -52,6 +52,7 @@ import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import static org.elasticsearch.common.xcontent.XContentHelper.createParser; +import static org.elasticsearch.index.rankeval.RatedRequest.validateEvaluatedQuery; /** * Instances of this class execute a collection of search intents (read: user @@ -99,15 +100,17 @@ public class TransportRankEvalAction extends HandledTransportAction ratedRequestsInSearch = new ArrayList<>(); for (RatedRequest ratedRequest : ratedRequests) { - SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest(); - if (ratedSearchSource == null) { + SearchSourceBuilder evaluationRequest = ratedRequest.getEvaluationRequest(); + if (evaluationRequest == null) { Map params = ratedRequest.getParams(); String templateId = ratedRequest.getTemplateId(); TemplateScript.Factory templateScript = scriptsWithoutParams.get(templateId); String resolvedRequest = templateScript.newInstance(params).execute(); try (XContentParser subParser = createParser(namedXContentRegistry, LoggingDeprecationHandler.INSTANCE, new BytesArray(resolvedRequest), XContentType.JSON)) { - ratedSearchSource = SearchSourceBuilder.fromXContent(subParser, false); + evaluationRequest = SearchSourceBuilder.fromXContent(subParser, false); + // check for parts that should not be part of a ranking evaluation request + validateEvaluatedQuery(evaluationRequest); } catch (IOException e) { // if we fail parsing, put the exception into the errors map and continue errors.put(ratedRequest.getId(), e); @@ -116,17 +119,17 @@ public class TransportRankEvalAction extends HandledTransportAction summaryFields = ratedRequest.getSummaryFields(); if (summaryFields.isEmpty()) { - ratedSearchSource.fetchSource(false); + evaluationRequest.fetchSource(false); } else { - ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]); + evaluationRequest.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]); } - SearchRequest searchRequest = new SearchRequest(request.indices(), ratedSearchSource); + SearchRequest searchRequest = new SearchRequest(request.indices(), evaluationRequest); searchRequest.indicesOptions(request.indicesOptions()); msearchRequest.add(searchRequest); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java index 196b50b7f61..084f29b8c9a 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java @@ -33,7 +33,11 @@ import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.suggest.SuggestBuilder; +import org.elasticsearch.search.suggest.SuggestBuilders; import org.elasticsearch.test.ESTestCase; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -165,7 +169,7 @@ public class RatedRequestsTests extends ESTestCase { private static RatedRequest mutateTestItem(RatedRequest original) { String id = original.getId(); - SearchSourceBuilder testRequest = original.getTestRequest(); + SearchSourceBuilder evaluationRequest = original.getEvaluationRequest(); List ratedDocs = original.getRatedDocs(); Map params = original.getParams(); List summaryFields = original.getSummaryFields(); @@ -177,11 +181,11 @@ public class RatedRequestsTests extends ESTestCase { id = randomValueOtherThan(id, () -> randomAlphaOfLength(10)); break; case 1: - if (testRequest != null) { - int size = randomValueOtherThan(testRequest.size(), () -> randomInt(Integer.MAX_VALUE)); - testRequest = new SearchSourceBuilder(); - testRequest.size(size); - testRequest.query(new MatchAllQueryBuilder()); + if (evaluationRequest != null) { + int size = randomValueOtherThan(evaluationRequest.size(), () -> randomInt(Integer.MAX_VALUE)); + evaluationRequest = new SearchSourceBuilder(); + evaluationRequest.size(size); + evaluationRequest.query(new MatchAllQueryBuilder()); } else { if (randomBoolean()) { Map mutated = new HashMap<>(); @@ -204,10 +208,10 @@ public class RatedRequestsTests extends ESTestCase { } RatedRequest ratedRequest; - if (testRequest == null) { + if (evaluationRequest == null) { ratedRequest = new RatedRequest(id, ratedDocs, params, templateId); } else { - ratedRequest = new RatedRequest(id, ratedDocs, testRequest); + ratedRequest = new RatedRequest(id, ratedDocs, evaluationRequest); } ratedRequest.addSummaryFields(summaryFields); @@ -258,6 +262,44 @@ public class RatedRequestsTests extends ESTestCase { expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, "templateId")); } + public void testAggsNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + SearchSourceBuilder query = new SearchSourceBuilder(); + query.aggregation(AggregationBuilders.terms("fieldName")); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query)); + assertEquals("Query in rated requests should not contain aggregations.", e.getMessage()); + } + + public void testSuggestionsNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + SearchSourceBuilder query = new SearchSourceBuilder(); + query.suggest(new SuggestBuilder().addSuggestion("id", SuggestBuilders.completionSuggestion("fieldname"))); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query)); + assertEquals("Query in rated requests should not contain a suggest section.", e.getMessage()); + } + + public void testHighlighterNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + SearchSourceBuilder query = new SearchSourceBuilder(); + query.highlighter(new HighlightBuilder().field("field")); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query)); + assertEquals("Query in rated requests should not contain a highlighter section.", e.getMessage()); + } + + public void testExplainNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder().explain(true))); + assertEquals("Query in rated requests should not use explain.", e.getMessage()); + } + + public void testProfileNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder().profile(true))); + assertEquals("Query in rated requests should not use profile.", e.getMessage()); + } + /** * test that modifying the order of index/docId to make sure it doesn't * matter for parsing xContent @@ -287,7 +329,7 @@ public class RatedRequestsTests extends ESTestCase { try (XContentParser parser = createParser(JsonXContent.jsonXContent, querySpecString)) { RatedRequest specification = RatedRequest.fromXContent(parser); assertEquals("my_qa_query", specification.getId()); - assertNotNull(specification.getTestRequest()); + assertNotNull(specification.getEvaluationRequest()); List ratedDocs = specification.getRatedDocs(); assertEquals(3, ratedDocs.size()); for (int i = 0; i < 3; i++) { diff --git a/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java b/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java index 50860ddd87b..0ad78ad0c7a 100644 --- a/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java +++ b/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -106,6 +107,43 @@ public class SmokeMultipleTemplatesIT extends ESIntegTestCase { assertEquals(0.9, response.getEvaluationResult(), Double.MIN_VALUE); } + public void testTemplateWithAggsFails() { + String template = "{ \"aggs\" : { \"avg_grade\" : { \"avg\" : { \"field\" : \"grade\" }}}}"; + assertTemplatedRequestFailures(template, "Query in rated requests should not contain aggregations."); + } + + public void testTemplateWithSuggestFails() { + String template = "{\"suggest\" : {\"my-suggestion\" : {\"text\" : \"Elastic\",\"term\" : {\"field\" : \"message\"}}}}"; + assertTemplatedRequestFailures(template, "Query in rated requests should not contain a suggest section."); + } + + public void testTemplateWithHighlighterFails() { + String template = "{\"highlight\" : { \"fields\" : {\"content\" : {}}}}"; + assertTemplatedRequestFailures(template, "Query in rated requests should not contain a highlighter section."); + } + + public void testTemplateWithProfileFails() { + String template = "{\"profile\" : \"true\" }"; + assertTemplatedRequestFailures(template, "Query in rated requests should not use profile."); + } + + public void testTemplateWithExplainFails() { + String template = "{\"explain\" : \"true\" }"; + assertTemplatedRequestFailures(template, "Query in rated requests should not use explain."); + } + + private static void assertTemplatedRequestFailures(String template, String expectedMessage) { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + RatedRequest ratedRequest = new RatedRequest("id", ratedDocs, Collections.singletonMap("param1", "value1"), "templateId"); + Collection templates = Collections.singletonList(new ScriptWithId("templateId", + new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, template, Collections.emptyMap()))); + RankEvalSpec rankEvalSpec = new RankEvalSpec(Collections.singletonList(ratedRequest), new PrecisionAtK(), templates); + RankEvalRequest rankEvalRequest = new RankEvalRequest(rankEvalSpec, new String[] { "test" }); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> client().execute(RankEvalAction.INSTANCE, rankEvalRequest).actionGet()); + assertEquals(expectedMessage, e.getMessage()); + } + private static List createRelevant(String... docs) { List relevant = new ArrayList<>(); for (String doc : docs) {