RankEvalRequest should implement IndicesRequest (#29188)
Change RankEvalRequest to implement IndicesRequest, so it gets treated in a similar fashion to regular search requests e.g. by security.
This commit is contained in:
parent
d6d3fb3c73
commit
e4b30071bb
|
@ -531,7 +531,7 @@ public final class Request {
|
|||
}
|
||||
|
||||
static Request rankEval(RankEvalRequest rankEvalRequest) throws IOException {
|
||||
String endpoint = endpoint(rankEvalRequest.getIndices(), Strings.EMPTY_ARRAY, "_rank_eval");
|
||||
String endpoint = endpoint(rankEvalRequest.indices(), Strings.EMPTY_ARRAY, "_rank_eval");
|
||||
HttpEntity entity = createEntity(rankEvalRequest.getRankEvalSpec(), REQUEST_BODY_CONTENT_TYPE);
|
||||
return new Request(HttpGet.METHOD_NAME, endpoint, Collections.emptyMap(), entity);
|
||||
}
|
||||
|
|
|
@ -22,24 +22,47 @@ package org.elasticsearch.index.rankeval;
|
|||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.IndicesRequest;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Request to perform a search ranking evaluation.
|
||||
*/
|
||||
public class RankEvalRequest extends ActionRequest {
|
||||
public class RankEvalRequest extends ActionRequest implements IndicesRequest.Replaceable {
|
||||
|
||||
private RankEvalSpec rankingEvaluationSpec;
|
||||
|
||||
private IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS;
|
||||
private String[] indices = Strings.EMPTY_ARRAY;
|
||||
|
||||
public RankEvalRequest(RankEvalSpec rankingEvaluationSpec, String[] indices) {
|
||||
this.rankingEvaluationSpec = rankingEvaluationSpec;
|
||||
setIndices(indices);
|
||||
this.rankingEvaluationSpec = Objects.requireNonNull(rankingEvaluationSpec, "ranking evaluation specification must not be null");
|
||||
indices(indices);
|
||||
}
|
||||
|
||||
RankEvalRequest(StreamInput in) throws IOException {
|
||||
super.readFrom(in);
|
||||
rankingEvaluationSpec = new RankEvalSpec(in);
|
||||
if (in.getVersion().onOrAfter(Version.V_6_3_0)) {
|
||||
indices = in.readStringArray();
|
||||
indicesOptions = IndicesOptions.readIndicesOptions(in);
|
||||
} else {
|
||||
// readStringArray uses readVInt for size, we used readInt in 6.2
|
||||
int indicesSize = in.readInt();
|
||||
String[] indices = new String[indicesSize];
|
||||
for (int i = 0; i < indicesSize; i++) {
|
||||
indices[i] = in.readString();
|
||||
}
|
||||
// no indices options yet
|
||||
}
|
||||
}
|
||||
|
||||
RankEvalRequest() {
|
||||
|
@ -72,7 +95,8 @@ public class RankEvalRequest extends ActionRequest {
|
|||
/**
|
||||
* Sets the indices the search will be executed on.
|
||||
*/
|
||||
public RankEvalRequest setIndices(String... indices) {
|
||||
@Override
|
||||
public RankEvalRequest indices(String... indices) {
|
||||
Objects.requireNonNull(indices, "indices must not be null");
|
||||
for (String index : indices) {
|
||||
Objects.requireNonNull(index, "index must not be null");
|
||||
|
@ -84,24 +108,23 @@ public class RankEvalRequest extends ActionRequest {
|
|||
/**
|
||||
* @return the indices for this request
|
||||
*/
|
||||
public String[] getIndices() {
|
||||
@Override
|
||||
public String[] indices() {
|
||||
return indices;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IndicesOptions indicesOptions() {
|
||||
return indicesOptions;
|
||||
}
|
||||
|
||||
public void indicesOptions(IndicesOptions indicesOptions) {
|
||||
this.indicesOptions = Objects.requireNonNull(indicesOptions, "indicesOptions must not be null");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readFrom(StreamInput in) throws IOException {
|
||||
super.readFrom(in);
|
||||
rankingEvaluationSpec = new RankEvalSpec(in);
|
||||
if (in.getVersion().onOrAfter(Version.V_6_3_0)) {
|
||||
indices = in.readStringArray();
|
||||
} else {
|
||||
// readStringArray uses readVInt for size, we used readInt in 6.2
|
||||
int indicesSize = in.readInt();
|
||||
String[] indices = new String[indicesSize];
|
||||
for (int i = 0; i < indicesSize; i++) {
|
||||
indices[i] = in.readString();
|
||||
}
|
||||
}
|
||||
throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -110,12 +133,33 @@ public class RankEvalRequest extends ActionRequest {
|
|||
rankingEvaluationSpec.writeTo(out);
|
||||
if (out.getVersion().onOrAfter(Version.V_6_3_0)) {
|
||||
out.writeStringArray(indices);
|
||||
indicesOptions.writeIndicesOptions(out);
|
||||
} else {
|
||||
// writeStringArray uses writeVInt for size, we used writeInt in 6.2
|
||||
out.writeInt(indices.length);
|
||||
for (String index : indices) {
|
||||
out.writeString(index);
|
||||
}
|
||||
// no indices options yet
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
RankEvalRequest that = (RankEvalRequest) o;
|
||||
return Objects.equals(indicesOptions, that.indicesOptions) &&
|
||||
Arrays.equals(indices, that.indices) &&
|
||||
Objects.equals(rankingEvaluationSpec, that.rankingEvaluationSpec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(indicesOptions, Arrays.hashCode(indices), rankingEvaluationSpec);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ public class RestRankEvalAction extends BaseRestHandler {
|
|||
}
|
||||
|
||||
private static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, XContentParser parser) {
|
||||
rankEvalRequest.setIndices(Strings.splitStringByCommaToArray(request.param("index")));
|
||||
rankEvalRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
|
||||
RankEvalSpec spec = RankEvalSpec.parse(parser);
|
||||
rankEvalRequest.setRankEvalSpec(spec);
|
||||
}
|
||||
|
|
|
@ -75,8 +75,8 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
public TransportRankEvalAction(Settings settings, ThreadPool threadPool, ActionFilters actionFilters,
|
||||
IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService,
|
||||
ScriptService scriptService, NamedXContentRegistry namedXContentRegistry) {
|
||||
super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver,
|
||||
RankEvalRequest::new);
|
||||
super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, RankEvalRequest::new,
|
||||
indexNameExpressionResolver);
|
||||
this.scriptService = scriptService;
|
||||
this.namedXContentRegistry = namedXContentRegistry;
|
||||
this.client = client;
|
||||
|
@ -126,7 +126,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
|||
} else {
|
||||
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
|
||||
}
|
||||
msearchRequest.add(new SearchRequest(request.getIndices(), ratedSearchSource));
|
||||
msearchRequest.add(new SearchRequest(request.indices(), ratedSearchSource));
|
||||
}
|
||||
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
|
||||
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, metric,
|
||||
|
|
|
@ -89,7 +89,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().setIndices("test"))
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().indices("test"))
|
||||
.actionGet();
|
||||
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
|
||||
// second is 1/6, divided by 2 to get the average
|
||||
|
@ -131,8 +131,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
metric = new PrecisionAtK(1, false, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
|
||||
builder.setRankEvalSpec(task);
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest(task, new String[] { "test" }));
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
// if we look only at top 3 documente, the expected P@3 for the first query is
|
||||
|
@ -164,8 +163,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
|
||||
new RankEvalRequest().setIndices("test"));
|
||||
builder.setRankEvalSpec(task);
|
||||
new RankEvalRequest(task, new String[] { "test" }));
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), 10E-14);
|
||||
|
@ -174,8 +172,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
metric = new DiscountedCumulativeGain(false, null, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
|
||||
builder.setRankEvalSpec(task);
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest(task, new String[] { "test" }));
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
assertEquals(12.39278926071437, response.getEvaluationResult(), 10E-14);
|
||||
|
@ -194,8 +191,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RankEvalSpec task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
|
||||
new RankEvalRequest().setIndices("test"));
|
||||
builder.setRankEvalSpec(task);
|
||||
new RankEvalRequest(task, new String[] { "test" }));
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
// the expected reciprocal rank for the amsterdam_query is 1/5
|
||||
|
@ -208,8 +204,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
metric = new MeanReciprocalRank(1, 3);
|
||||
task = new RankEvalSpec(specifications, metric);
|
||||
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
|
||||
builder.setRankEvalSpec(task);
|
||||
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest(task, new String[] { "test" }));
|
||||
|
||||
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
// limiting to top 3 results, the amsterdam_query has no relevant document in it
|
||||
|
@ -240,7 +235,7 @@ public class RankEvalRequestIT extends ESIntegTestCase {
|
|||
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
|
||||
|
||||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
|
||||
new RankEvalRequest().setIndices("test"));
|
||||
new RankEvalRequest(task, new String[] { "test" }));
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.rankeval;
|
||||
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.common.util.ArrayUtils;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.junit.AfterClass;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class RankEvalRequestTests extends AbstractWireSerializingTestCase<RankEvalRequest> {
|
||||
|
||||
private static RankEvalPlugin rankEvalPlugin = new RankEvalPlugin();
|
||||
|
||||
@AfterClass
|
||||
public static void releasePluginResources() throws IOException {
|
||||
rankEvalPlugin.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(rankEvalPlugin.getNamedXContent());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(rankEvalPlugin.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RankEvalRequest createTestInstance() {
|
||||
int numberOfIndices = randomInt(3);
|
||||
String[] indices = new String[numberOfIndices];
|
||||
for (int i=0; i < numberOfIndices; i++) {
|
||||
indices[i] = randomAlphaOfLengthBetween(5, 10);
|
||||
}
|
||||
RankEvalRequest rankEvalRequest = new RankEvalRequest(RankEvalSpecTests.createTestItem(), indices);
|
||||
IndicesOptions indicesOptions = IndicesOptions.fromOptions(
|
||||
randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean());
|
||||
rankEvalRequest.indicesOptions(indicesOptions);
|
||||
return rankEvalRequest;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Reader<RankEvalRequest> instanceReader() {
|
||||
return RankEvalRequest::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RankEvalRequest mutateInstance(RankEvalRequest instance) throws IOException {
|
||||
RankEvalRequest mutation = copyInstance(instance);
|
||||
List<Runnable> mutators = new ArrayList<>();
|
||||
mutators.add(() -> mutation.indices(ArrayUtils.concat(instance.indices(), new String[] { randomAlphaOfLength(10) })));
|
||||
mutators.add(() -> mutation.indicesOptions(randomValueOtherThan(instance.indicesOptions(),
|
||||
() -> IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()))));
|
||||
mutators.add(() -> mutation.setRankEvalSpec(RankEvalSpecTests.mutateTestItem(instance.getRankEvalSpec())));
|
||||
randomFrom(mutators).run();
|
||||
return mutation;
|
||||
}
|
||||
}
|
|
@ -70,7 +70,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
return result;
|
||||
}
|
||||
|
||||
private static RankEvalSpec createTestItem() throws IOException {
|
||||
static RankEvalSpec createTestItem() {
|
||||
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
|
||||
() -> PrecisionAtKTests.createTestItem(),
|
||||
() -> MeanReciprocalRankTests.createTestItem(),
|
||||
|
@ -87,6 +87,9 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
builder.field("field", randomAlphaOfLengthBetween(1, 5));
|
||||
builder.endObject();
|
||||
script = Strings.toString(builder);
|
||||
} catch (IOException e) {
|
||||
// this shouldn't happen in tests, re-throw just not to swallow it
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
templates = new HashSet<>();
|
||||
|
@ -156,7 +159,7 @@ public class RankEvalSpecTests extends ESTestCase {
|
|||
checkEqualsAndHashCode(createTestItem(), RankEvalSpecTests::copy, RankEvalSpecTests::mutateTestItem);
|
||||
}
|
||||
|
||||
private static RankEvalSpec mutateTestItem(RankEvalSpec original) {
|
||||
static RankEvalSpec mutateTestItem(RankEvalSpec original) {
|
||||
List<RatedRequest> ratedRequests = new ArrayList<>(original.getRatedRequests());
|
||||
EvaluationMetric metric = original.getMetric();
|
||||
Map<String, Script> templates = new HashMap<>(original.getTemplates());
|
||||
|
|
|
@ -102,7 +102,7 @@ public class SmokeMultipleTemplatesIT extends ESIntegTestCase {
|
|||
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
|
||||
builder.setRankEvalSpec(task);
|
||||
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().setIndices("test")).actionGet();
|
||||
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().indices("test")).actionGet();
|
||||
assertEquals(0.9, response.getEvaluationResult(), Double.MIN_VALUE);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue