Use msearch instead of single search (#27520)
Change TransportRankEvalAction to use one MultiSearchRequest instead of issuing several parallel search requests to simplify the transport action.
This commit is contained in:
parent
5661b1c3df
commit
1352b7c6ea
|
@ -20,8 +20,10 @@
|
||||||
package org.elasticsearch.index.rankeval;
|
package org.elasticsearch.index.rankeval;
|
||||||
|
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.search.MultiSearchRequest;
|
||||||
|
import org.elasticsearch.action.search.MultiSearchResponse;
|
||||||
|
import org.elasticsearch.action.search.MultiSearchResponse.Item;
|
||||||
import org.elasticsearch.action.search.SearchRequest;
|
import org.elasticsearch.action.search.SearchRequest;
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
|
||||||
import org.elasticsearch.action.support.ActionFilters;
|
import org.elasticsearch.action.support.ActionFilters;
|
||||||
import org.elasticsearch.action.support.HandledTransportAction;
|
import org.elasticsearch.action.support.HandledTransportAction;
|
||||||
import org.elasticsearch.client.Client;
|
import org.elasticsearch.client.Client;
|
||||||
|
@ -41,15 +43,12 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collection;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Queue;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
|
import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
|
||||||
|
|
||||||
|
@ -69,7 +68,6 @@ import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
|
||||||
public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> {
|
public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> {
|
||||||
private Client client;
|
private Client client;
|
||||||
private ScriptService scriptService;
|
private ScriptService scriptService;
|
||||||
Queue<RequestTask> taskQueue = new ConcurrentLinkedQueue<>();
|
|
||||||
private NamedXContentRegistry namedXContentRegistry;
|
private NamedXContentRegistry namedXContentRegistry;
|
||||||
|
|
||||||
@Inject
|
@Inject
|
||||||
|
@ -88,10 +86,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||||
RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
|
RankEvalSpec evaluationSpecification = request.getRankEvalSpec();
|
||||||
List<String> indices = evaluationSpecification.getIndices();
|
List<String> indices = evaluationSpecification.getIndices();
|
||||||
|
|
||||||
Collection<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
|
List<RatedRequest> ratedRequests = evaluationSpecification.getRatedRequests();
|
||||||
AtomicInteger responseCounter = new AtomicInteger(ratedRequests.size());
|
|
||||||
Map<String, EvalQueryQuality> partialResults = new ConcurrentHashMap<>(
|
|
||||||
ratedRequests.size());
|
|
||||||
Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size());
|
Map<String, Exception> errors = new ConcurrentHashMap<>(ratedRequests.size());
|
||||||
|
|
||||||
Map<String, TemplateScript.Factory> scriptsWithoutParams = new HashMap<>();
|
Map<String, TemplateScript.Factory> scriptsWithoutParams = new HashMap<>();
|
||||||
|
@ -99,6 +94,9 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||||
scriptsWithoutParams.put(entry.getKey(), scriptService.compile(entry.getValue(), TemplateScript.CONTEXT));
|
scriptsWithoutParams.put(entry.getKey(), scriptService.compile(entry.getValue(), TemplateScript.CONTEXT));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MultiSearchRequest msearchRequest = new MultiSearchRequest();
|
||||||
|
msearchRequest.maxConcurrentSearchRequests(evaluationSpecification.getMaxConcurrentSearches());
|
||||||
|
List<RatedRequest> ratedRequestsInSearch = new ArrayList<>();
|
||||||
for (RatedRequest ratedRequest : ratedRequests) {
|
for (RatedRequest ratedRequest : ratedRequests) {
|
||||||
SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest();
|
SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest();
|
||||||
if (ratedSearchSource == null) {
|
if (ratedSearchSource == null) {
|
||||||
|
@ -109,87 +107,63 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
|
||||||
try (XContentParser subParser = createParser(namedXContentRegistry, new BytesArray(resolvedRequest), XContentType.JSON)) {
|
try (XContentParser subParser = createParser(namedXContentRegistry, new BytesArray(resolvedRequest), XContentType.JSON)) {
|
||||||
ratedSearchSource = SearchSourceBuilder.fromXContent(subParser);
|
ratedSearchSource = SearchSourceBuilder.fromXContent(subParser);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
listener.onFailure(e);
|
// if we fail parsing, put the exception into the errors map and continue
|
||||||
|
errors.put(ratedRequest.getId(), e);
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ratedRequestsInSearch.add(ratedRequest);
|
||||||
List<String> summaryFields = ratedRequest.getSummaryFields();
|
List<String> summaryFields = ratedRequest.getSummaryFields();
|
||||||
if (summaryFields.isEmpty()) {
|
if (summaryFields.isEmpty()) {
|
||||||
ratedSearchSource.fetchSource(false);
|
ratedSearchSource.fetchSource(false);
|
||||||
} else {
|
} else {
|
||||||
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
|
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
|
||||||
}
|
}
|
||||||
|
msearchRequest.add(new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource));
|
||||||
SearchRequest templatedRequest = new SearchRequest(indices.toArray(new String[indices.size()]), ratedSearchSource);
|
|
||||||
final RankEvalActionListener searchListener = new RankEvalActionListener(listener,
|
|
||||||
evaluationSpecification.getMetric(), ratedRequest, partialResults, errors, responseCounter);
|
|
||||||
RequestTask task = new RequestTask(templatedRequest, searchListener);
|
|
||||||
taskQueue.add(task);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute top n tasks, further execution is triggered in RankEvalActionListener
|
|
||||||
for (int i = 0; (i < Math.min(ratedRequests.size(),
|
|
||||||
evaluationSpecification.getMaxConcurrentSearches())); i++) {
|
|
||||||
RequestTask task = taskQueue.poll();
|
|
||||||
client.search(task.request, task.searchListener);
|
|
||||||
}
|
}
|
||||||
|
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
|
||||||
|
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, evaluationSpecification.getMetric(),
|
||||||
|
ratedRequestsInSearch.toArray(new RatedRequest[ratedRequestsInSearch.size()]), errors));
|
||||||
}
|
}
|
||||||
|
|
||||||
private class RequestTask {
|
class RankEvalActionListener implements ActionListener<MultiSearchResponse> {
|
||||||
private SearchRequest request;
|
|
||||||
private RankEvalActionListener searchListener;
|
|
||||||
|
|
||||||
RequestTask(SearchRequest request, RankEvalActionListener listener) {
|
private final ActionListener<RankEvalResponse> listener;
|
||||||
this.request = request;
|
private final RatedRequest[] specifications;
|
||||||
this.searchListener = listener;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class RankEvalActionListener implements ActionListener<SearchResponse> {
|
private final Map<String, Exception> errors;
|
||||||
|
private final EvaluationMetric metric;
|
||||||
|
|
||||||
private ActionListener<RankEvalResponse> listener;
|
RankEvalActionListener(ActionListener<RankEvalResponse> listener, EvaluationMetric metric, RatedRequest[] specifications,
|
||||||
private RatedRequest specification;
|
Map<String, Exception> errors) {
|
||||||
private Map<String, EvalQueryQuality> requestDetails;
|
|
||||||
private Map<String, Exception> errors;
|
|
||||||
private EvaluationMetric metric;
|
|
||||||
private AtomicInteger responseCounter;
|
|
||||||
|
|
||||||
RankEvalActionListener(ActionListener<RankEvalResponse> listener,
|
|
||||||
EvaluationMetric metric, RatedRequest specification,
|
|
||||||
Map<String, EvalQueryQuality> details, Map<String, Exception> errors,
|
|
||||||
AtomicInteger responseCounter) {
|
|
||||||
this.listener = listener;
|
this.listener = listener;
|
||||||
this.metric = metric;
|
this.metric = metric;
|
||||||
this.errors = errors;
|
this.errors = errors;
|
||||||
this.specification = specification;
|
this.specifications = specifications;
|
||||||
this.requestDetails = details;
|
|
||||||
this.responseCounter = responseCounter;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onResponse(SearchResponse searchResponse) {
|
public void onResponse(MultiSearchResponse multiSearchResponse) {
|
||||||
SearchHit[] hits = searchResponse.getHits().getHits();
|
int responsePosition = 0;
|
||||||
EvalQueryQuality queryQuality = metric.evaluate(specification.getId(), hits,
|
Map<String, EvalQueryQuality> responseDetails = new HashMap<>(specifications.length);
|
||||||
specification.getRatedDocs());
|
for (Item response : multiSearchResponse.getResponses()) {
|
||||||
requestDetails.put(specification.getId(), queryQuality);
|
RatedRequest specification = specifications[responsePosition];
|
||||||
handleResponse();
|
if (response.isFailure() == false) {
|
||||||
|
SearchHit[] hits = response.getResponse().getHits().getHits();
|
||||||
|
EvalQueryQuality queryQuality = this.metric.evaluate(specification.getId(), hits, specification.getRatedDocs());
|
||||||
|
responseDetails.put(specification.getId(), queryQuality);
|
||||||
|
} else {
|
||||||
|
errors.put(specification.getId(), response.getFailure());
|
||||||
|
}
|
||||||
|
responsePosition++;
|
||||||
|
}
|
||||||
|
listener.onResponse(new RankEvalResponse(this.metric.combine(responseDetails.values()), responseDetails, this.errors));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onFailure(Exception exception) {
|
public void onFailure(Exception exception) {
|
||||||
errors.put(specification.getId(), exception);
|
listener.onFailure(exception);
|
||||||
handleResponse();
|
|
||||||
}
|
|
||||||
|
|
||||||
private void handleResponse() {
|
|
||||||
if (responseCounter.decrementAndGet() == 0) {
|
|
||||||
listener.onResponse(new RankEvalResponse(metric.combine(requestDetails.values()), requestDetails, errors));
|
|
||||||
} else {
|
|
||||||
if (!taskQueue.isEmpty()) {
|
|
||||||
RequestTask task = taskQueue.poll();
|
|
||||||
client.search(task.request, task.searchListener);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue