Move indices field from RankEvalSpec to RankEvalRequest (#28341)

Currently we store the indices specified in the request URL together with the
rest of the ranking evaluation specification in RankEvalSpec. This is not ideal
since e.g. the indices are not rendered to xContent and so cannot be parsed
back. Instead we should keep them in RankEvalRequest.
This commit is contained in:
Christoph Büscher 2018-03-19 16:26:02 +01:00 committed by GitHub
parent 3025295f7e
commit 80532229a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 64 additions and 71 deletions

View File

@ -82,7 +82,6 @@ import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
@ -517,9 +516,7 @@ public final class Request {
}
static Request rankEval(RankEvalRequest rankEvalRequest) throws IOException {
// TODO maybe indices should be property of RankEvalRequest and not of the spec
List<String> indices = rankEvalRequest.getRankEvalSpec().getIndices();
String endpoint = endpoint(indices.toArray(new String[indices.size()]), Strings.EMPTY_ARRAY, "_rank_eval");
String endpoint = endpoint(rankEvalRequest.getIndices(), Strings.EMPTY_ARRAY, "_rank_eval");
HttpEntity entity = createEntity(rankEvalRequest.getRankEvalSpec(), REQUEST_BODY_CONTENT_TYPE);
return new Request(HttpGet.METHOD_NAME, endpoint, Collections.emptyMap(), entity);
}

View File

@ -71,9 +71,9 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
specifications.add(berlinRequest);
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec spec = new RankEvalSpec(specifications, metric);
spec.addIndices(Collections.singletonList("index"));
RankEvalResponse response = execute(new RankEvalRequest(spec), highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
RankEvalResponse response = execute(new RankEvalRequest(spec, new String[] { "index" }), highLevelClient()::rankEval,
highLevelClient()::rankEvalAsync);
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the second is 1/6, divided by 2 to get the average
double expectedPrecision = (1.0 / 6.0 + 4.0 / 6.0) / 2.0;
assertEquals(expectedPrecision, response.getEvaluationResult(), Double.MIN_VALUE);

View File

@ -104,7 +104,6 @@ import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -1109,8 +1108,7 @@ public class RequestTests extends ESTestCase {
Collections.singletonList(new RatedRequest("queryId", Collections.emptyList(), new SearchSourceBuilder())),
new PrecisionAtK());
String[] indices = randomIndicesNames(0, 5);
spec.addIndices(Arrays.asList(indices));
RankEvalRequest rankEvalRequest = new RankEvalRequest(spec);
RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, indices);
Request request = Request.rankEval(rankEvalRequest);
StringJoiner endpoint = new StringJoiner("/", "/", "");

View File

@ -19,12 +19,15 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import java.io.IOException;
import java.util.Objects;
/**
* Request to perform a search ranking evaluation.
@ -32,9 +35,11 @@ import java.io.IOException;
public class RankEvalRequest extends ActionRequest {
private RankEvalSpec rankingEvaluationSpec;
private String[] indices = Strings.EMPTY_ARRAY;
public RankEvalRequest(RankEvalSpec rankingEvaluationSpec) {
public RankEvalRequest(RankEvalSpec rankingEvaluationSpec, String[] indices) {
this.rankingEvaluationSpec = rankingEvaluationSpec;
setIndices(indices);
}
RankEvalRequest() {
@ -64,16 +69,53 @@ public class RankEvalRequest extends ActionRequest {
this.rankingEvaluationSpec = task;
}
/**
 * Sets the indices the search will be executed on.
 *
 * @param indices the target index names; must be non-null and contain no null entries
 * @return this request, for call chaining
 */
public RankEvalRequest setIndices(String... indices) {
    Objects.requireNonNull(indices, "indices must not be null");
    for (int i = 0; i < indices.length; i++) {
        Objects.requireNonNull(indices[i], "index must not be null");
    }
    this.indices = indices;
    return this;
}
/**
 * Returns the indices this ranking evaluation will be executed on.
 */
public String[] getIndices() {
    return this.indices;
}
@Override
public void readFrom(StreamInput in) throws IOException {
    super.readFrom(in);
    rankingEvaluationSpec = new RankEvalSpec(in);
    if (in.getVersion().onOrAfter(Version.V_6_3_0)) {
        indices = in.readStringArray();
    } else {
        // readStringArray uses readVInt for size, we used readInt in 6.2
        int indicesSize = in.readInt();
        // Assign to the field: the previous code declared a local String[] that
        // shadowed the field, so the indices read from the stream were discarded
        // and the field silently stayed empty for pre-6.3 peers.
        indices = new String[indicesSize];
        for (int i = 0; i < indicesSize; i++) {
            indices[i] = in.readString();
        }
    }
}
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    rankingEvaluationSpec.writeTo(out);
    if (out.getVersion().before(Version.V_6_3_0)) {
        // writeStringArray uses writeVInt for size, we used writeInt in 6.2
        out.writeInt(indices.length);
        for (int i = 0; i < indices.length; i++) {
            out.writeString(indices[i]);
        }
    } else {
        out.writeStringArray(indices);
    }
}
}

View File

@ -58,8 +58,6 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
private static final int MAX_CONCURRENT_SEARCHES = 10;
/** optional: Templates to base test requests on */
private Map<String, Script> templates = new HashMap<>();
/** the indices this ranking evaluation targets */
private final List<String> indices;
public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric, Collection<ScriptWithId> templates) {
this.metric = Objects.requireNonNull(metric, "Cannot evaluate ranking if no evaluation metric is provided.");
@ -81,7 +79,6 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
this.templates.put(idScript.id, idScript.script);
}
}
this.indices = new ArrayList<>();
}
public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric) {
@ -102,11 +99,6 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
this.templates.put(key, value);
}
maxConcurrentSearches = in.readVInt();
int indicesSize = in.readInt();
indices = new ArrayList<>(indicesSize);
for (int i = 0; i < indicesSize; i++) {
this.indices.add(in.readString());
}
}
@Override
@ -122,10 +114,6 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
entry.getValue().writeTo(out);
}
out.writeVInt(maxConcurrentSearches);
out.writeInt(indices.size());
for (String index : indices) {
out.writeString(index);
}
}
/** Returns the metric to use for quality evaluation.*/
@ -153,14 +141,6 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
this.maxConcurrentSearches = maxConcurrentSearches;
}
public void addIndices(List<String> indices) {
this.indices.addAll(indices);
}
public List<String> getIndices() {
return Collections.unmodifiableList(indices);
}
private static final ParseField TEMPLATES_FIELD = new ParseField("templates");
private static final ParseField METRIC_FIELD = new ParseField("metric");
private static final ParseField REQUESTS_FIELD = new ParseField("requests");
@ -262,12 +242,11 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
return Objects.equals(ratedRequests, other.ratedRequests) &&
Objects.equals(metric, other.metric) &&
Objects.equals(maxConcurrentSearches, other.maxConcurrentSearches) &&
Objects.equals(templates, other.templates) &&
Objects.equals(indices, other.indices);
Objects.equals(templates, other.templates);
}
@Override
public final int hashCode() {
return Objects.hash(ratedRequests, metric, templates, maxConcurrentSearches, indices);
return Objects.hash(ratedRequests, metric, templates, maxConcurrentSearches);
}
}

View File

@ -29,8 +29,6 @@ import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.rest.RestRequest.Method.POST;
@ -110,9 +108,8 @@ public class RestRankEvalAction extends BaseRestHandler {
}
private static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, XContentParser parser) {
List<String> indices = Arrays.asList(Strings.splitStringByCommaToArray(request.param("index")));
rankEvalRequest.setIndices(Strings.splitStringByCommaToArray(request.param("index")));
RankEvalSpec spec = RankEvalSpec.parse(parser);
spec.addIndices(indices);
rankEvalRequest.setRankEvalSpec(spec);
}

View File

@ -85,7 +85,6 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
@Override
protected void doExecute(RankEvalRequest request, ActionListener<RankEvalResponse> listener) {
RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
List<String> indices = evaluationSpecification.getIndices();
EvaluationMetric metric = evaluationSpecification.getMetric();
List<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
@ -127,7 +126,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
} else {
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
}
msearchRequest.add(new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource));
msearchRequest.add(new SearchRequest(request.getIndices(), ratedSearchSource));
}
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, metric,

View File

@ -30,7 +30,6 @@ import org.junit.Before;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@ -85,13 +84,12 @@ public class RankEvalRequestIT extends ESIntegTestCase {
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(),
RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request())
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().setIndices("test"))
.actionGet();
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
// second is 1/6, divided by 2 to get the average
@ -132,9 +130,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
// test that a different window size k affects the result
metric = new PrecisionAtK(1, false, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
@ -165,9 +162,9 @@ public class RankEvalRequestIT extends ESIntegTestCase {
DiscountedCumulativeGain metric = new DiscountedCumulativeGain(false, null, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
@ -176,9 +173,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
// test that a different window size k affects the result
metric = new DiscountedCumulativeGain(false, null, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
@ -196,9 +192,9 @@ public class RankEvalRequestIT extends ESIntegTestCase {
MeanReciprocalRank metric = new MeanReciprocalRank(1, 10);
RankEvalSpec task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
@ -211,9 +207,8 @@ public class RankEvalRequestIT extends ESIntegTestCase {
// test that a different window size k affects the result
metric = new MeanReciprocalRank(1, 3);
task = new RankEvalSpec(specifications, metric);
task.addIndices(Collections.singletonList("test"));
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
@ -229,8 +224,6 @@ public class RankEvalRequestIT extends ESIntegTestCase {
* field) will produce an error in the response
*/
public void testBadQuery() {
List<String> indices = Arrays.asList(new String[] { "test" });
List<RatedRequest> specifications = new ArrayList<>();
SearchSourceBuilder amsterdamQuery = new SearchSourceBuilder();
amsterdamQuery.query(new MatchAllQueryBuilder());
@ -245,9 +238,9 @@ public class RankEvalRequestIT extends ESIntegTestCase {
specifications.add(brokenRequest);
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());
task.addIndices(indices);
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();

View File

@ -109,7 +109,6 @@ public class RankEvalSpecTests extends ESTestCase {
for (int i = 0; i < size; i++) {
indices.add(randomAlphaOfLengthBetween(0, 50));
}
spec.addIndices(indices);
return spec;
}
@ -117,11 +116,7 @@ public class RankEvalSpecTests extends ESTestCase {
RankEvalSpec testItem = createTestItem();
XContentBuilder shuffled = shuffleXContent(testItem.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
try (XContentParser parser = createParser(JsonXContent.jsonXContent, BytesReference.bytes(shuffled))) {
RankEvalSpec parsedItem = RankEvalSpec.parse(parser);
// indices, come from URL parameters, so they don't survive xContent roundtrip
// for the sake of being able to use equals() next, we add it to the parsed object
parsedItem.addIndices(testItem.getIndices());
assertNotSame(testItem, parsedItem);
assertEquals(testItem, parsedItem);
assertEquals(testItem.hashCode(), parsedItem.hashCode());
@ -165,9 +160,8 @@ public class RankEvalSpecTests extends ESTestCase {
List<RatedRequest> ratedRequests = new ArrayList<>(original.getRatedRequests());
EvaluationMetric metric = original.getMetric();
Map<String, Script> templates = new HashMap<>(original.getTemplates());
List<String> indices = new ArrayList<>(original.getIndices());
int mutate = randomIntBetween(0, 3);
int mutate = randomIntBetween(0, 2);
switch (mutate) {
case 0:
RatedRequest request = RatedRequestsTests.createTestItem(true);
@ -183,9 +177,6 @@ public class RankEvalSpecTests extends ESTestCase {
case 2:
templates.put("mutation", new Script(ScriptType.INLINE, "mustache", randomAlphaOfLength(10), new HashMap<>()));
break;
case 3:
indices.add(randomAlphaOfLength(5));
break;
default:
throw new IllegalStateException("Requested to modify more than available parameters.");
}
@ -195,7 +186,6 @@ public class RankEvalSpecTests extends ESTestCase {
scripts.add(new ScriptWithId(entry.getKey(), entry.getValue()));
}
RankEvalSpec result = new RankEvalSpec(ratedRequests, metric, scripts);
result.addIndices(indices);
return result;
}

View File

@ -72,7 +72,6 @@ public class SmokeMultipleTemplatesIT extends ESIntegTestCase {
}
public void testPrecisionAtRequest() throws IOException {
List<String> indices = Arrays.asList(new String[] { "test" });
List<RatedRequest> specifications = new ArrayList<>();
Map<String, Object> ams_params = new HashMap<>();
@ -100,11 +99,10 @@ public class SmokeMultipleTemplatesIT extends ESIntegTestCase {
Set<ScriptWithId> templates = new HashSet<>();
templates.add(template);
RankEvalSpec task = new RankEvalSpec(specifications, metric, templates);
task.addIndices(indices);
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().setIndices("test")).actionGet();
assertEquals(0.9, response.getEvaluationResult(), Double.MIN_VALUE);
}