Use msearch instead of single search (#27520)

Change TransportRankEvalAction to use one MultiSearchRequest instead of issuing several parallel search requests to simplify the transport action.
This commit is contained in:
Christoph Büscher 2017-11-27 10:15:59 +01:00 committed by GitHub
parent 5661b1c3df
commit 1352b7c6ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 40 additions and 66 deletions

View File

@ -20,8 +20,10 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.MultiSearchResponse.Item;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
@ -41,15 +43,12 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.util.Collection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
@ -69,7 +68,6 @@ import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> {
private Client client;
private ScriptService scriptService;
Queue<RequestTask> taskQueue = new ConcurrentLinkedQueue<>();
private NamedXContentRegistry namedXContentRegistry;
@Inject
@ -88,10 +86,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
List<String> indices = evaluationSpecification.getIndices();
Collection<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
AtomicInteger responseCounter = new AtomicInteger(ratedRequests.size());
Map<String, EvalQueryQuality> partialResults = new ConcurrentHashMap<>(
ratedRequests.size());
List<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size());
Map<String, TemplateScript.Factory> scriptsWithoutParams = new HashMap<>();
@ -99,6 +94,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
scriptsWithoutParams.put(entry.getKey(), scriptService.compile(entry.getValue(), TemplateScript.CONTEXT));
}
MultiSearchRequest msearchRequest = new MultiSearchRequest();
msearchRequest.maxConcurrentSearchRequests(evaluationSpecification.getMaxConcurrentSearches());
List<RatedRequest> ratedRequestsInSearch = new ArrayList<>();
for (RatedRequest ratedRequest : ratedRequests) {
SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest();
if (ratedSearchSource == null) {
@ -109,87 +107,63 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
try (XContentParser subParser = createParser(namedXContentRegistry, new BytesArray(resolvedRequest), XContentType.JSON)) {
ratedSearchSource = SearchSourceBuilder.fromXContent(subParser);
} catch (IOException e) {
listener.onFailure(e);
// if we fail parsing, put the exception into the errors map and continue
errors.put(ratedRequest.getId(), e);
continue;
}
}
ratedRequestsInSearch.add(ratedRequest);
List<String> summaryFields = ratedRequest.getSummaryFields();
if (summaryFields.isEmpty()) {
ratedSearchSource.fetchSource(false);
} else {
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
}
SearchRequest templatedRequest = new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource);
final RankEvalActionListener searchListener = new RankEvalActionListener(listener,
evaluationSpecification.getMetric(), ratedRequest, partialResults, errors, responseCounter);
RequestTask task = new RequestTask(templatedRequest, searchListener);
taskQueue.add(task);
}
// Execute top n tasks, further execution is triggered in RankEvalActionListener
for (int i = 0; (i < Math.min(ratedRequests.size(),
evaluationSpecification.getMaxConcurrentSearches())); i++) {
RequestTask task = taskQueue.poll();
client.search(task.request, task.searchListener);
msearchRequest.add(new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource));
}
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, evaluationSpecification.getMetric(),
ratedRequestsInSearch.toArray(new RatedRequest[ratedRequestsInSearch.size()]), errors));
}
private class RequestTask {
private SearchRequest request;
private RankEvalActionListener searchListener;
class RankEvalActionListener implements ActionListener<MultiSearchResponse> {
RequestTask(SearchRequest request, RankEvalActionListener listener) {
this.request = request;
this.searchListener = listener;
}
}
private final ActionListener<RankEvalResponse> listener;
private final RatedRequest[] specifications;
class RankEvalActionListener implements ActionListener<SearchResponse> {
private final Map<String, Exception> errors;
private final EvaluationMetric metric;
private ActionListener<RankEvalResponse> listener;
private RatedRequest specification;
private Map<String, EvalQueryQuality> requestDetails;
private Map<String, Exception> errors;
private EvaluationMetric metric;
private AtomicInteger responseCounter;
RankEvalActionListener(ActionListener<RankEvalResponse> listener,
EvaluationMetric metric, RatedRequest specification,
Map<String, EvalQueryQuality> details, Map<String, Exception> errors,
AtomicInteger responseCounter) {
RankEvalActionListener(ActionListener<RankEvalResponse> listener, EvaluationMetric metric, RatedRequest[] specifications,
Map<String, Exception> errors) {
this.listener = listener;
this.metric = metric;
this.errors = errors;
this.specification = specification;
this.requestDetails = details;
this.responseCounter = responseCounter;
this.specifications = specifications;
}
@Override
public void onResponse(SearchResponse searchResponse) {
SearchHit[] hits = searchResponse.getHits().getHits();
EvalQueryQuality queryQuality = metric.evaluate(specification.getId(), hits,
specification.getRatedDocs());
requestDetails.put(specification.getId(), queryQuality);
handleResponse();
public void onResponse(MultiSearchResponse multiSearchResponse) {
int responsePosition = 0;
Map<String, EvalQueryQuality> responseDetails = new HashMap<>(specifications.length);
for (Item response : multiSearchResponse.getResponses()) {
RatedRequest specification = specifications[responsePosition];
if (response.isFailure() == false) {
SearchHit[] hits = response.getResponse().getHits().getHits();
EvalQueryQuality queryQuality = this.metric.evaluate(specification.getId(), hits, specification.getRatedDocs());
responseDetails.put(specification.getId(), queryQuality);
} else {
errors.put(specification.getId(), response.getFailure());
}
responsePosition++;
}
listener.onResponse(new RankEvalResponse(this.metric.combine(responseDetails.values()), responseDetails, this.errors));
}
@Override
public void onFailure(Exception exception) {
errors.put(specification.getId(), exception);
handleResponse();
}
private void handleResponse() {
if (responseCounter.decrementAndGet() == 0) {
listener.onResponse(new RankEvalResponse(metric.combine(requestDetails.values()), requestDetails, errors));
} else {
if (!taskQueue.isEmpty()) {
RequestTask task = taskQueue.poll();
client.search(task.request, task.searchListener);
}
}
listener.onFailure(exception);
}
}
}