diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java index bae54425017..d8fd26c1152 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalSpec.java @@ -23,23 +23,17 @@ import org.elasticsearch.action.support.ToXContentToBytes; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParsingException; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.script.Script; -import org.elasticsearch.script.ScriptContext; -import org.elasticsearch.search.builder.SearchSourceBuilder; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; -import java.util.Map; import java.util.Objects; /** @@ -154,24 +148,7 @@ public class RankEvalSpec extends ToXContentToBytes implements Writeable { } public static RankEvalSpec parse(XContentParser parser, RankEvalContext context, boolean templated) throws IOException { - RankEvalSpec spec = PARSER.parse(parser, context); - - if (templated) { - for (RatedRequest query_spec : spec.getSpecifications()) { - Map params = query_spec.getParams(); - Script scriptWithParams = new Script(spec.template.getType(), spec.template.getLang(), spec.template.getIdOrCode(), params); - String resolvedRequest = ((BytesReference) (context.getScriptService() - .executable(scriptWithParams, ScriptContext.Standard.SEARCH).run())).utf8ToString(); - try (XContentParser subParser = XContentFactory.xContent(resolvedRequest).createParser(resolvedRequest)) { - QueryParseContext parseContext = new QueryParseContext(context.getSearchRequestParsers().queryParsers, subParser, - context.getParseFieldMatcher()); - SearchSourceBuilder templateResult = SearchSourceBuilder.fromXContent(parseContext, context.getAggs(), - context.getSuggesters(), context.getSearchExtParsers()); - query_spec.setTestRequest(templateResult); - } - } - } - return spec; + return PARSER.parse(parser, context); } @Override diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java index 6dc53244a6d..2b05e1a4976 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RatedRequest.java @@ -74,7 +74,8 @@ public class RatedRequest extends ToXContentToBytes implements Writeable { public RatedRequest(StreamInput in) throws IOException { this.specId = in.readString(); - testRequest = new SearchSourceBuilder(in); + testRequest = in.readOptionalWriteable(SearchSourceBuilder::new); + int indicesSize = in.readInt(); indices = new ArrayList<>(indicesSize); for (int i = 0; i < indicesSize; i++) { @@ -101,7 +102,8 @@ public class RatedRequest extends ToXContentToBytes implements Writeable { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(specId); - testRequest.writeTo(out); + out.writeOptionalWriteable(testRequest); + out.writeInt(indices.size()); for (String index : indices) { out.writeString(index); @@ -255,8 +257,9 @@ public class RatedRequest extends ToXContentToBytes implements Writeable { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ID_FIELD.getPreferredName(), this.specId); - if (testRequest != null) + if (testRequest != null) { builder.field(REQUEST_FIELD.getPreferredName(), this.testRequest); + } builder.startObject(PARAMS_FIELD.getPreferredName()); for (Entry entry : this.params.entrySet()) { builder.field(entry.getKey(), entry.getValue()); diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java index 9577fde5f31..7d73cf5d9ca 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java @@ -26,14 +26,24 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryParseContext; +import org.elasticsearch.script.CompiledScript; +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchRequestParsers; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import java.io.IOException; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -51,12 +61,17 @@ import java.util.concurrent.atomic.AtomicInteger; * */ public class TransportRankEvalAction extends HandledTransportAction { private Client client; - + private ScriptService scriptService; + private SearchRequestParsers searchRequestParsers; + @Inject public TransportRankEvalAction(Settings settings, ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService) { + IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService, + SearchRequestParsers searchRequestParsers, ScriptService scriptService) { super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver, RankEvalRequest::new); + this.searchRequestParsers = searchRequestParsers; + this.scriptService = scriptService; this.client = client; } @@ -69,10 +84,25 @@ public class TransportRankEvalAction extends HandledTransportAction partialResults = new ConcurrentHashMap<>(specifications.size()); Map errors = new ConcurrentHashMap<>(specifications.size()); + CompiledScript scriptWithoutParams = null; + if (qualityTask.getTemplate() != null) { + scriptWithoutParams = scriptService.compile(qualityTask.getTemplate(), ScriptContext.Standard.SEARCH, new HashMap<>()); + } for (RatedRequest querySpecification : specifications) { final RankEvalActionListener searchListener = new RankEvalActionListener(listener, qualityTask.getMetric(), querySpecification, partialResults, errors, responseCounter); SearchSourceBuilder specRequest = querySpecification.getTestRequest(); + if (specRequest == null) { + Map params = querySpecification.getParams(); + String resolvedRequest = ((BytesReference) (scriptService.executable(scriptWithoutParams, params).run())).utf8ToString(); + try (XContentParser subParser = XContentFactory.xContent(resolvedRequest).createParser(resolvedRequest)) { + QueryParseContext parseContext = new QueryParseContext(searchRequestParsers.queryParsers, subParser, parseFieldMatcher); + specRequest = SearchSourceBuilder.fromXContent(parseContext, searchRequestParsers.aggParsers, + searchRequestParsers.suggesters, searchRequestParsers.searchExtParsers); + } catch (IOException e) { + listener.onFailure(e); + } + } List summaryFields = querySpecification.getSummaryFields(); if (summaryFields.isEmpty()) { specRequest.fetchSource(false); diff --git a/qa/smoke-test-rank-eval-with-mustache/build.gradle b/qa/smoke-test-rank-eval-with-mustache/build.gradle index 4860d5469af..4fdbaa04502 100644 --- a/qa/smoke-test-rank-eval-with-mustache/build.gradle +++ b/qa/smoke-test-rank-eval-with-mustache/build.gradle @@ -19,9 +19,9 @@ apply plugin: 'elasticsearch.rest-test' -/* + dependencies { testCompile project(path: ':modules:rank-eval', configuration: 'runtime') testCompile project(path: ':modules:lang-mustache', configuration: 'runtime') } -*/ + diff --git a/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/smoketest/SmokeMultipleTemplatesIT.java b/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/smoketest/SmokeMultipleTemplatesIT.java new file mode 100644 index 00000000000..02213642d39 --- /dev/null +++ b/qa/smoke-test-rank-eval-with-mustache/src/test/java/org/elasticsearch/smoketest/SmokeMultipleTemplatesIT.java @@ -0,0 +1,122 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.smoketest; + +import org.elasticsearch.index.rankeval.Precision; +import org.elasticsearch.index.rankeval.RankEvalAction; +import org.elasticsearch.index.rankeval.RankEvalPlugin; +import org.elasticsearch.index.rankeval.RankEvalRequest; +import org.elasticsearch.index.rankeval.RankEvalRequestBuilder; +import org.elasticsearch.index.rankeval.RankEvalResponse; +import org.elasticsearch.index.rankeval.RankEvalSpec; +import org.elasticsearch.index.rankeval.RatedDocument; +import org.elasticsearch.index.rankeval.RatedRequest; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.test.ESIntegTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + + +public class SmokeMultipleTemplatesIT extends ESIntegTestCase { + + @Override + protected Collection> transportClientPlugins() { + return Arrays.asList(RankEvalPlugin.class); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(RankEvalPlugin.class); + } + + @Before + public void setup() { + createIndex("test"); + ensureGreen(); + + client().prepareIndex("test", "testtype").setId("1") + .setSource("text", "berlin", "title", "Berlin, Germany").get(); + client().prepareIndex("test", "testtype").setId("2") + .setSource("text", "amsterdam").get(); + client().prepareIndex("test", "testtype").setId("3") + .setSource("text", "amsterdam").get(); + client().prepareIndex("test", "testtype").setId("4") + .setSource("text", "amsterdam").get(); + client().prepareIndex("test", "testtype").setId("5") + .setSource("text", "amsterdam").get(); + client().prepareIndex("test", "testtype").setId("6") + .setSource("text", "amsterdam").get(); + refresh(); + } + + public void testPrecisionAtRequest() throws IOException { + List indices = Arrays.asList(new String[] { "test" }); + List types = Arrays.asList(new String[] { "testtype" }); + + List specifications = new ArrayList<>(); + RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", null, indices, types, createRelevant("2", "3", "4", "5")); + Map ams_params = new HashMap<>(); + ams_params.put("querystring", "amsterdam"); + amsterdamRequest.setParams(ams_params); + specifications.add(amsterdamRequest); + + RatedRequest berlinRequest = new RatedRequest("berlin_query", null, indices, types, createRelevant("1")); + Map berlin_params = new HashMap<>(); + berlin_params.put("querystring", "berlin"); + berlinRequest.setParams(berlin_params); + specifications.add(berlinRequest); + + Precision metric = new Precision(); + RankEvalSpec task = new RankEvalSpec(specifications, metric); + task.setTemplate( + new Script( + ScriptType.INLINE, + "mustache", "{\"query\": {\"match\": {\"text\": \"{{querystring}}\"}}}", + new HashMap<>())); + + RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest()); + builder.setRankEvalSpec(task); + + RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet(); + assertEquals(0.9, response.getQualityLevel(), Double.MIN_VALUE); + } + + private static List createRelevant(String... docs) { + List relevant = new ArrayList<>(); + for (String doc : docs) { + relevant.add(new RatedDocument("test", "testtype", doc, Rating.RELEVANT.ordinal())); + } + return relevant; + } + + public enum Rating { + IRRELEVANT, RELEVANT; + } + + }