Use msearch instead of single search (#27520)

Change TransportRankEvalAction to use one MultiSearchRequest instead of issuing several parallel search requests to simplify the transport action.
This commit is contained in:
Christoph Büscher 2017-11-27 10:15:59 +01:00 committed by GitHub
parent 5661b1c3df
commit 1352b7c6ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 40 additions and 66 deletions

View File

@ -20,8 +20,10 @@
package org.elasticsearch.index.rankeval; package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.MultiSearchResponse.Item;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
@ -41,15 +43,12 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import java.io.IOException; import java.io.IOException;
import java.util.Collection; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.common.xcontent.XContentHelper.createParser; import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
@ -69,7 +68,6 @@ import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> { public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> {
private Client client; private Client client;
private ScriptService scriptService; private ScriptService scriptService;
Queue<RequestTask> taskQueue = new ConcurrentLinkedQueue<>();
private NamedXContentRegistry namedXContentRegistry; private NamedXContentRegistry namedXContentRegistry;
@Inject @Inject
@ -88,10 +86,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
RankEvalSpec evaluationSpecification = request.getRankEvalSpec(); RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
List<String> indices = evaluationSpecification.getIndices(); List<String> indices = evaluationSpecification.getIndices();
Collection<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests(); List<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
AtomicInteger responseCounter = new AtomicInteger(ratedRequests.size());
Map<String, EvalQueryQuality> partialResults = new ConcurrentHashMap<>(
ratedRequests.size());
Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size()); Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size());
Map<String, TemplateScript.Factory> scriptsWithoutParams = new HashMap<>(); Map<String, TemplateScript.Factory> scriptsWithoutParams = new HashMap<>();
@ -99,6 +94,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
scriptsWithoutParams.put(entry.getKey(), scriptService.compile(entry.getValue(), TemplateScript.CONTEXT)); scriptsWithoutParams.put(entry.getKey(), scriptService.compile(entry.getValue(), TemplateScript.CONTEXT));
} }
MultiSearchRequest msearchRequest = new MultiSearchRequest();
msearchRequest.maxConcurrentSearchRequests(evaluationSpecification.getMaxConcurrentSearches());
List<RatedRequest> ratedRequestsInSearch = new ArrayList<>();
for (RatedRequest ratedRequest : ratedRequests) { for (RatedRequest ratedRequest : ratedRequests) {
SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest(); SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest();
if (ratedSearchSource == null) { if (ratedSearchSource == null) {
@ -109,87 +107,63 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
try (XContentParser subParser = createParser(namedXContentRegistry, new BytesArray(resolvedRequest), XContentType.JSON)) { try (XContentParser subParser = createParser(namedXContentRegistry, new BytesArray(resolvedRequest), XContentType.JSON)) {
ratedSearchSource = SearchSourceBuilder.fromXContent(subParser); ratedSearchSource = SearchSourceBuilder.fromXContent(subParser);
} catch (IOException e) { } catch (IOException e) {
listener.onFailure(e); // if we fail parsing, put the exception into the errors map and continue
errors.put(ratedRequest.getId(), e);
continue;
} }
} }
ratedRequestsInSearch.add(ratedRequest);
List<String> summaryFields = ratedRequest.getSummaryFields(); List<String> summaryFields = ratedRequest.getSummaryFields();
if (summaryFields.isEmpty()) { if (summaryFields.isEmpty()) {
ratedSearchSource.fetchSource(false); ratedSearchSource.fetchSource(false);
} else { } else {
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]); ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
} }
msearchRequest.add(new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource));
SearchRequest templatedRequest = new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource);
final RankEvalActionListener searchListener = new RankEvalActionListener(listener,
evaluationSpecification.getMetric(), ratedRequest, partialResults, errors, responseCounter);
RequestTask task = new RequestTask(templatedRequest, searchListener);
taskQueue.add(task);
}
// Execute top n tasks, further execution is triggered in RankEvalActionListener
for (int i = 0; (i < Math.min(ratedRequests.size(),
evaluationSpecification.getMaxConcurrentSearches())); i++) {
RequestTask task = taskQueue.poll();
client.search(task.request, task.searchListener);
} }
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, evaluationSpecification.getMetric(),
ratedRequestsInSearch.toArray(new RatedRequest[ratedRequestsInSearch.size()]), errors));
} }
private class RequestTask { class RankEvalActionListener implements ActionListener<MultiSearchResponse> {
private SearchRequest request;
private RankEvalActionListener searchListener;
RequestTask(SearchRequest request, RankEvalActionListener listener) { private final ActionListener<RankEvalResponse> listener;
this.request = request; private final RatedRequest[] specifications;
this.searchListener = listener;
}
}
class RankEvalActionListener implements ActionListener<SearchResponse> { private final Map<String, Exception> errors;
private final EvaluationMetric metric;
private ActionListener<RankEvalResponse> listener; RankEvalActionListener(ActionListener<RankEvalResponse> listener, EvaluationMetric metric, RatedRequest[] specifications,
private RatedRequest specification; Map<String, Exception> errors) {
private Map<String, EvalQueryQuality> requestDetails;
private Map<String, Exception> errors;
private EvaluationMetric metric;
private AtomicInteger responseCounter;
RankEvalActionListener(ActionListener<RankEvalResponse> listener,
EvaluationMetric metric, RatedRequest specification,
Map<String, EvalQueryQuality> details, Map<String, Exception> errors,
AtomicInteger responseCounter) {
this.listener = listener; this.listener = listener;
this.metric = metric; this.metric = metric;
this.errors = errors; this.errors = errors;
this.specification = specification; this.specifications = specifications;
this.requestDetails = details;
this.responseCounter = responseCounter;
} }
@Override @Override
public void onResponse(SearchResponse searchResponse) { public void onResponse(MultiSearchResponse multiSearchResponse) {
SearchHit[] hits = searchResponse.getHits().getHits(); int responsePosition = 0;
EvalQueryQuality queryQuality = metric.evaluate(specification.getId(), hits, Map<String, EvalQueryQuality> responseDetails = new HashMap<>(specifications.length);
specification.getRatedDocs()); for (Item response : multiSearchResponse.getResponses()) {
requestDetails.put(specification.getId(), queryQuality); RatedRequest specification = specifications[responsePosition];
handleResponse(); if (response.isFailure() == false) {
SearchHit[] hits = response.getResponse().getHits().getHits();
EvalQueryQuality queryQuality = this.metric.evaluate(specification.getId(), hits, specification.getRatedDocs());
responseDetails.put(specification.getId(), queryQuality);
} else {
errors.put(specification.getId(), response.getFailure());
}
responsePosition++;
}
listener.onResponse(new RankEvalResponse(this.metric.combine(responseDetails.values()), responseDetails, this.errors));
} }
@Override @Override
public void onFailure(Exception exception) { public void onFailure(Exception exception) {
errors.put(specification.getId(), exception); listener.onFailure(exception);
handleResponse();
}
private void handleResponse() {
if (responseCounter.decrementAndGet() == 0) {
listener.onResponse(new RankEvalResponse(metric.combine(requestDetails.values()), requestDetails, errors));
} else {
if (!taskQueue.isEmpty()) {
RequestTask task = taskQueue.poll();
client.search(task.request, task.searchListener);
}
}
} }
} }
} }