From cc93131318057f53ef59efe6beceb50e4d139406 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Mon, 14 May 2018 17:36:26 +0200 Subject: [PATCH] Forbid expensive query parts in ranking evaluation (#30151) Currently the ranking evaluation API accepts the full query syntax for the queries specified in the evaluation set and executes them via multi search. This potentially runs costly aggregations and suggestions too. This change adds checks that forbid using aggregations, suggesters, highlighters and the explain and profile options in the queries that are run as part of the ranking evaluation since they are irrelevant in the context of this API. --- .../index/rankeval/RankEvalSpec.java | 6 +- .../index/rankeval/RatedRequest.java | 101 ++++++++++++------ .../rankeval/TransportRankEvalAction.java | 17 +-- .../index/rankeval/RatedRequestsTests.java | 60 +++++++++-- .../rankeval/SmokeMultipleTemplatesIT.java | 38 +++++++ 5 files changed, 171 insertions(+), 51 deletions(-) diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java index 8e0828fcfca..22875139c9b 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java @@ -57,7 +57,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject { /** Default max number of requests. 
*/ private static final int MAX_CONCURRENT_SEARCHES = 10; /** optional: Templates to base test requests on */ - private Map templates = new HashMap<>(); + private final Map templates = new HashMap<>(); public RankEvalSpec(List ratedRequests, EvaluationMetric metric, Collection templates) { this.metric = Objects.requireNonNull(metric, "Cannot evaluate ranking if no evaluation metric is provided."); @@ -68,8 +68,8 @@ public class RankEvalSpec implements Writeable, ToXContentObject { this.ratedRequests = ratedRequests; if (templates == null || templates.isEmpty()) { for (RatedRequest request : ratedRequests) { - if (request.getTestRequest() == null) { - throw new IllegalStateException("Cannot evaluate ranking if neither template nor test request is " + if (request.getEvaluationRequest() == null) { + throw new IllegalStateException("Cannot evaluate ranking if neither template nor evaluation request is " + "provided. Seen for request id: " + request.getId()); } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java index 392ce5d0633..79dd693b3ac 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java @@ -75,9 +75,12 @@ public class RatedRequest implements Writeable, ToXContentObject { private final String id; private final List summaryFields; private final List ratedDocs; - // Search request to execute for this rated request. This can be null if template and corresponding parameters are supplied. + /** + * Search request to execute for this rated request. 
This can be null in + * case the query is supplied as a template with corresponding parameters + */ @Nullable - private SearchSourceBuilder testRequest; + private final SearchSourceBuilder evaluationRequest; /** * Map of parameters to use for filling a query template, can be used * instead of providing testRequest. @@ -86,27 +89,49 @@ public class RatedRequest implements Writeable, ToXContentObject { @Nullable private String templateId; - private RatedRequest(String id, List ratedDocs, SearchSourceBuilder testRequest, + /** + * Create a rated request with template ids and parameters. + * + * @param id a unique name for this rated request + * @param ratedDocs a list of document ratings + * @param params template parameters + * @param templateId a template id + */ + public RatedRequest(String id, List ratedDocs, Map params, + String templateId) { + this(id, ratedDocs, null, params, templateId); + } + + /** + * Create a rated request using a {@link SearchSourceBuilder} to define the + * evaluated query. 
+ * + * @param id a unique name for this rated request + * @param ratedDocs a list of document ratings + * @param evaluatedQuery the query that is evaluated + */ + public RatedRequest(String id, List ratedDocs, SearchSourceBuilder evaluatedQuery) { + this(id, ratedDocs, evaluatedQuery, new HashMap<>(), null); + } + + private RatedRequest(String id, List ratedDocs, SearchSourceBuilder evaluatedQuery, Map params, String templateId) { - if (params != null && (params.size() > 0 && testRequest != null)) { + if (params != null && (params.size() > 0 && evaluatedQuery != null)) { throw new IllegalArgumentException( - "Ambiguous rated request: Set both, verbatim test request and test request " - + "template parameters."); + "Ambiguous rated request: Set both, verbatim test request and test request " + "template parameters."); } - if (templateId != null && testRequest != null) { + if (templateId != null && evaluatedQuery != null) { throw new IllegalArgumentException( - "Ambiguous rated request: Set both, verbatim test request and test request " - + "template parameters."); + "Ambiguous rated request: Set both, verbatim test request and test request " + "template parameters."); } - if ((params == null || params.size() < 1) && testRequest == null) { - throw new IllegalArgumentException( - "Need to set at least test request or test request template parameters."); + if ((params == null || params.size() < 1) && evaluatedQuery == null) { + throw new IllegalArgumentException("Need to set at least test request or test request template parameters."); } if ((params != null && params.size() > 0) && templateId == null) { - throw new IllegalArgumentException( - "If template parameters are supplied need to set id of template to apply " - + "them to too."); + throw new IllegalArgumentException("If template parameters are supplied need to set id of template to apply " + "them to too."); } + validateEvaluatedQuery(evaluatedQuery); + // check that not two documents with same _index/id are 
specified Set docKeys = new HashSet<>(); for (RatedDocument doc : ratedDocs) { @@ -118,7 +143,7 @@ public class RatedRequest implements Writeable, ToXContentObject { } this.id = id; - this.testRequest = testRequest; + this.evaluationRequest = evaluatedQuery; this.ratedDocs = new ArrayList<>(ratedDocs); if (params != null) { this.params = new HashMap<>(params); @@ -129,18 +154,30 @@ public class RatedRequest implements Writeable, ToXContentObject { this.summaryFields = new ArrayList<>(); } - public RatedRequest(String id, List ratedDocs, Map params, - String templateId) { - this(id, ratedDocs, null, params, templateId); + static void validateEvaluatedQuery(SearchSourceBuilder evaluationRequest) { + // ensure that testRequest, if set, does not contain aggregation, suggest or highlighting section + if (evaluationRequest != null) { + if (evaluationRequest.suggest() != null) { + throw new IllegalArgumentException("Query in rated requests should not contain a suggest section."); + } + if (evaluationRequest.aggregations() != null) { + throw new IllegalArgumentException("Query in rated requests should not contain aggregations."); + } + if (evaluationRequest.highlighter() != null) { + throw new IllegalArgumentException("Query in rated requests should not contain a highlighter section."); + } + if (evaluationRequest.explain() != null && evaluationRequest.explain()) { + throw new IllegalArgumentException("Query in rated requests should not use explain."); + } + if (evaluationRequest.profile()) { + throw new IllegalArgumentException("Query in rated requests should not use profile."); + } + } } - public RatedRequest(String id, List ratedDocs, SearchSourceBuilder testRequest) { - this(id, ratedDocs, testRequest, new HashMap<>(), null); - } - - public RatedRequest(StreamInput in) throws IOException { + RatedRequest(StreamInput in) throws IOException { this.id = in.readString(); - testRequest = in.readOptionalWriteable(SearchSourceBuilder::new); + evaluationRequest = 
in.readOptionalWriteable(SearchSourceBuilder::new); int intentSize = in.readInt(); ratedDocs = new ArrayList<>(intentSize); @@ -159,7 +196,7 @@ public class RatedRequest implements Writeable, ToXContentObject { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); - out.writeOptionalWriteable(testRequest); + out.writeOptionalWriteable(evaluationRequest); out.writeInt(ratedDocs.size()); for (RatedDocument ratedDoc : ratedDocs) { @@ -173,8 +210,8 @@ public class RatedRequest implements Writeable, ToXContentObject { out.writeOptionalString(this.templateId); } - public SearchSourceBuilder getTestRequest() { - return testRequest; + public SearchSourceBuilder getEvaluationRequest() { + return evaluationRequest; } /** return the user supplied request id */ @@ -240,8 +277,8 @@ public class RatedRequest implements Writeable, ToXContentObject { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ID_FIELD.getPreferredName(), this.id); - if (testRequest != null) { - builder.field(REQUEST_FIELD.getPreferredName(), this.testRequest); + if (evaluationRequest != null) { + builder.field(REQUEST_FIELD.getPreferredName(), this.evaluationRequest); } builder.startArray(RATINGS_FIELD.getPreferredName()); for (RatedDocument doc : this.ratedDocs) { @@ -285,7 +322,7 @@ public class RatedRequest implements Writeable, ToXContentObject { RatedRequest other = (RatedRequest) obj; - return Objects.equals(id, other.id) && Objects.equals(testRequest, other.testRequest) + return Objects.equals(id, other.id) && Objects.equals(evaluationRequest, other.evaluationRequest) && Objects.equals(summaryFields, other.summaryFields) && Objects.equals(ratedDocs, other.ratedDocs) && Objects.equals(params, other.params) @@ -294,7 +331,7 @@ public class RatedRequest implements Writeable, ToXContentObject { @Override public final int hashCode() { - return Objects.hash(id, testRequest, summaryFields, 
ratedDocs, params, + return Objects.hash(id, evaluationRequest, summaryFields, ratedDocs, params, templateId); } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index 019ae274466..e0a0b3ea133 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -52,6 +52,7 @@ import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import static org.elasticsearch.common.xcontent.XContentHelper.createParser; +import static org.elasticsearch.index.rankeval.RatedRequest.validateEvaluatedQuery; /** * Instances of this class execute a collection of search intents (read: user @@ -99,15 +100,17 @@ public class TransportRankEvalAction extends HandledTransportAction ratedRequestsInSearch = new ArrayList<>(); for (RatedRequest ratedRequest : ratedRequests) { - SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest(); - if (ratedSearchSource == null) { + SearchSourceBuilder evaluationRequest = ratedRequest.getEvaluationRequest(); + if (evaluationRequest == null) { Map params = ratedRequest.getParams(); String templateId = ratedRequest.getTemplateId(); TemplateScript.Factory templateScript = scriptsWithoutParams.get(templateId); String resolvedRequest = templateScript.newInstance(params).execute(); try (XContentParser subParser = createParser(namedXContentRegistry, LoggingDeprecationHandler.INSTANCE, new BytesArray(resolvedRequest), XContentType.JSON)) { - ratedSearchSource = SearchSourceBuilder.fromXContent(subParser, false); + evaluationRequest = SearchSourceBuilder.fromXContent(subParser, false); + // check for parts that should not be part of a ranking evaluation request + validateEvaluatedQuery(evaluationRequest); } catch (IOException e) { // if we 
fail parsing, put the exception into the errors map and continue errors.put(ratedRequest.getId(), e); @@ -116,17 +119,17 @@ public class TransportRankEvalAction extends HandledTransportAction summaryFields = ratedRequest.getSummaryFields(); if (summaryFields.isEmpty()) { - ratedSearchSource.fetchSource(false); + evaluationRequest.fetchSource(false); } else { - ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]); + evaluationRequest.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]); } - SearchRequest searchRequest = new SearchRequest(request.indices(), ratedSearchSource); + SearchRequest searchRequest = new SearchRequest(request.indices(), evaluationRequest); searchRequest.indicesOptions(request.indicesOptions()); msearchRequest.add(searchRequest); } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java index 196b50b7f61..084f29b8c9a 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RatedRequestsTests.java @@ -33,7 +33,11 @@ import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.suggest.SuggestBuilder; +import org.elasticsearch.search.suggest.SuggestBuilders; import org.elasticsearch.test.ESTestCase; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -165,7 +169,7 @@ public class RatedRequestsTests extends ESTestCase { 
private static RatedRequest mutateTestItem(RatedRequest original) { String id = original.getId(); - SearchSourceBuilder testRequest = original.getTestRequest(); + SearchSourceBuilder evaluationRequest = original.getEvaluationRequest(); List ratedDocs = original.getRatedDocs(); Map params = original.getParams(); List summaryFields = original.getSummaryFields(); @@ -177,11 +181,11 @@ public class RatedRequestsTests extends ESTestCase { id = randomValueOtherThan(id, () -> randomAlphaOfLength(10)); break; case 1: - if (testRequest != null) { - int size = randomValueOtherThan(testRequest.size(), () -> randomInt(Integer.MAX_VALUE)); - testRequest = new SearchSourceBuilder(); - testRequest.size(size); - testRequest.query(new MatchAllQueryBuilder()); + if (evaluationRequest != null) { + int size = randomValueOtherThan(evaluationRequest.size(), () -> randomInt(Integer.MAX_VALUE)); + evaluationRequest = new SearchSourceBuilder(); + evaluationRequest.size(size); + evaluationRequest.query(new MatchAllQueryBuilder()); } else { if (randomBoolean()) { Map mutated = new HashMap<>(); @@ -204,10 +208,10 @@ public class RatedRequestsTests extends ESTestCase { } RatedRequest ratedRequest; - if (testRequest == null) { + if (evaluationRequest == null) { ratedRequest = new RatedRequest(id, ratedDocs, params, templateId); } else { - ratedRequest = new RatedRequest(id, ratedDocs, testRequest); + ratedRequest = new RatedRequest(id, ratedDocs, evaluationRequest); } ratedRequest.addSummaryFields(summaryFields); @@ -258,6 +262,44 @@ public class RatedRequestsTests extends ESTestCase { expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, "templateId")); } + public void testAggsNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + SearchSourceBuilder query = new SearchSourceBuilder(); + query.aggregation(AggregationBuilders.terms("fieldName")); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () 
-> new RatedRequest("id", ratedDocs, query)); + assertEquals("Query in rated requests should not contain aggregations.", e.getMessage()); + } + + public void testSuggestionsNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + SearchSourceBuilder query = new SearchSourceBuilder(); + query.suggest(new SuggestBuilder().addSuggestion("id", SuggestBuilders.completionSuggestion("fieldname"))); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query)); + assertEquals("Query in rated requests should not contain a suggest section.", e.getMessage()); + } + + public void testHighlighterNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + SearchSourceBuilder query = new SearchSourceBuilder(); + query.highlighter(new HighlightBuilder().field("field")); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query)); + assertEquals("Query in rated requests should not contain a highlighter section.", e.getMessage()); + } + + public void testExplainNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder().explain(true))); + assertEquals("Query in rated requests should not use explain.", e.getMessage()); + } + + public void testProfileNotAllowed() { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder().profile(true))); + assertEquals("Query in rated requests should not use profile.", e.getMessage()); + } + /** * test that modifying the order of index/docId to make sure it doesn't * matter for parsing xContent @@ -287,7 +329,7 @@ public class 
RatedRequestsTests extends ESTestCase { try (XContentParser parser = createParser(JsonXContent.jsonXContent, querySpecString)) { RatedRequest specification = RatedRequest.fromXContent(parser); assertEquals("my_qa_query", specification.getId()); - assertNotNull(specification.getTestRequest()); + assertNotNull(specification.getEvaluationRequest()); List ratedDocs = specification.getRatedDocs(); assertEquals(3, ratedDocs.size()); for (int i = 0; i < 3; i++) { diff --git a/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java b/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java index 50860ddd87b..0ad78ad0c7a 100644 --- a/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java +++ b/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/index/rankeval/SmokeMultipleTemplatesIT.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -106,6 +107,43 @@ public class SmokeMultipleTemplatesIT extends ESIntegTestCase { assertEquals(0.9, response.getEvaluationResult(), Double.MIN_VALUE); } + public void testTemplateWithAggsFails() { + String template = "{ \"aggs\" : { \"avg_grade\" : { \"avg\" : { \"field\" : \"grade\" }}}}"; + assertTemplatedRequestFailures(template, "Query in rated requests should not contain aggregations."); + } + + public void testTemplateWithSuggestFails() { + String template = "{\"suggest\" : {\"my-suggestion\" : {\"text\" : \"Elastic\",\"term\" : {\"field\" : \"message\"}}}}"; + assertTemplatedRequestFailures(template, "Query in rated requests should not contain a suggest section."); + } + + public void testTemplateWithHighlighterFails() { + String template = "{\"highlight\" : 
{ \"fields\" : {\"content\" : {}}}}"; + assertTemplatedRequestFailures(template, "Query in rated requests should not contain a highlighter section."); + } + + public void testTemplateWithProfileFails() { + String template = "{\"profile\" : \"true\" }"; + assertTemplatedRequestFailures(template, "Query in rated requests should not use profile."); + } + + public void testTemplateWithExplainFails() { + String template = "{\"explain\" : \"true\" }"; + assertTemplatedRequestFailures(template, "Query in rated requests should not use explain."); + } + + private static void assertTemplatedRequestFailures(String template, String expectedMessage) { + List ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1)); + RatedRequest ratedRequest = new RatedRequest("id", ratedDocs, Collections.singletonMap("param1", "value1"), "templateId"); + Collection templates = Collections.singletonList(new ScriptWithId("templateId", + new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, template, Collections.emptyMap()))); + RankEvalSpec rankEvalSpec = new RankEvalSpec(Collections.singletonList(ratedRequest), new PrecisionAtK(), templates); + RankEvalRequest rankEvalRequest = new RankEvalRequest(rankEvalSpec, new String[] { "test" }); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> client().execute(RankEvalAction.INSTANCE, rankEvalRequest).actionGet()); + assertEquals(expectedMessage, e.getMessage()); + } + private static List createRelevant(String... docs) { List relevant = new ArrayList<>(); for (String doc : docs) {