Adding listeners for search requests that collect results

This commit is contained in:
Christoph Büscher 2016-08-04 16:11:57 +02:00
parent e71e29b3a0
commit 795017ddfa
2 changed files with 55 additions and 38 deletions

View File

@ -27,8 +27,8 @@ import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.search.SearchHit;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Vector;
/**
* Classes implementing this interface provide a means to compute the quality of a result list
@ -76,7 +76,7 @@ public abstract class RankedListQualityMetric implements NamedWriteable {
return rc;
}
double combine(Vector<EvalQueryQuality> partialResults) {
double combine(Collection<EvalQueryQuality> partialResults) {
return partialResults.stream().mapToDouble(EvalQueryQuality::getQualityLevel).sum() / partialResults.size();
}
}

View File

@ -19,31 +19,25 @@
package org.elasticsearch.index.rankeval;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.AutoCreateIndex;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.action.SearchTransportService;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.controller.SearchPhaseController;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Instances of this class execute a collection of search intents (read: user supplied query parameters) against a set of
@ -56,25 +50,13 @@ import java.util.Vector;
* set of search intents as averaged precision at n.
* */
public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequest, RankEvalResponse> {
private SearchPhaseController searchPhaseController;
private TransportService transportService;
private SearchTransportService searchTransportService;
private ClusterService clusterService;
private ActionFilters actionFilters;
private Client client;
@Inject
public TransportRankEvalAction(Settings settings, ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, ClusterService clusterService, ScriptService scriptService,
AutoCreateIndex autoCreateIndex, Client client, TransportService transportService, SearchPhaseController searchPhaseController,
SearchTransportService searchTransportService, NamedWriteableRegistry namedWriteableRegistry) {
IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService) {
super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver,
RankEvalRequest::new);
this.searchPhaseController = searchPhaseController;
this.transportService = transportService;
this.searchTransportService = searchTransportService;
this.clusterService = clusterService;
this.actionFilters = actionFilters;
this.client = client;
}
@ -85,26 +67,61 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
Map<String, Collection<RatedDocumentKey>> unknownDocs = new HashMap<>();
Collection<QuerySpec> specifications = qualityTask.getSpecifications();
Vector<EvalQueryQuality> partialResults = new Vector<>(specifications.size());
for (QuerySpec spec : specifications) {
SearchSourceBuilder specRequest = spec.getTestRequest();
String[] indices = new String[spec.getIndices().size()];
spec.getIndices().toArray(indices);
AtomicInteger numberOfEvaluationQueries = new AtomicInteger(specifications.size());
Map<String, EvalQueryQuality> partialResults = new ConcurrentHashMap<>(specifications.size());
for (QuerySpec querySpecification : specifications) {
final RankEvalActionListener searchListener = new RankEvalActionListener(listener, qualityTask, querySpecification,
partialResults, unknownDocs, numberOfEvaluationQueries);
SearchSourceBuilder specRequest = querySpecification.getTestRequest();
String[] indices = new String[querySpecification.getIndices().size()];
querySpecification.getIndices().toArray(indices);
SearchRequest templatedRequest = new SearchRequest(indices, specRequest);
String[] types = new String[spec.getTypes().size()];
spec.getTypes().toArray(types);
String[] types = new String[querySpecification.getTypes().size()];
querySpecification.getTypes().toArray(types);
templatedRequest.types(types);
client.search(templatedRequest, searchListener);
}
}
ActionFuture<SearchResponse> response = client.search(templatedRequest);
SearchHits hits = response.actionGet().getHits();
public static class RankEvalActionListener implements ActionListener<SearchResponse> {
EvalQueryQuality queryQuality = metric.evaluate(hits.getHits(), spec.getRatedDocs());
partialResults.addElement(queryQuality);
unknownDocs.put(spec.getSpecId(), queryQuality.getUnknownDocs());
private ActionListener<RankEvalResponse> listener;
private QuerySpec specification;
private Map<String, EvalQueryQuality> partialResults;
private RankEvalSpec task;
private Map<String, Collection<RatedDocumentKey>> unknownDocs;
private AtomicInteger responseCounter;
public RankEvalActionListener(ActionListener<RankEvalResponse> listener, RankEvalSpec task, QuerySpec specification,
Map<String, EvalQueryQuality> partialResults, Map<String, Collection<RatedDocumentKey>> unknownDocs,
AtomicInteger responseCounter) {
this.listener = listener;
this.task = task;
this.specification = specification;
this.partialResults = partialResults;
this.unknownDocs = unknownDocs;
this.responseCounter = responseCounter;
}
// TODO add other statistics like micro/macro avg?
RankEvalResponse response = new RankEvalResponse(qualityTask.getTaskId(), metric.combine(partialResults), unknownDocs);
listener.onResponse(response);
@Override
public void onResponse(SearchResponse searchResponse) {
SearchHits hits = searchResponse.getHits();
EvalQueryQuality queryQuality = task.getEvaluator().evaluate(hits.getHits(), specification.getRatedDocs());
partialResults.put(specification.getSpecId(), queryQuality);
unknownDocs.put(specification.getSpecId(), queryQuality.getUnknownDocs());
if (responseCounter.decrementAndGet() == 0) {
// TODO add other statistics like micro/macro avg?
listener.onResponse(
new RankEvalResponse(task.getTaskId(), task.getEvaluator().combine(partialResults.values()), unknownDocs));
}
}
@Override
public void onFailure(Exception exception) {
// TODO this fails the complete request. Investigate if maybe we want to collect errors and still return partial result.
this.listener.onFailure(exception);
}
}
}