[ML] Ensure total hits are tracked (#36374)

This is in preparation for the anticipated change
that will disable accurate total hits tracking in
searches.
This commit is contained in:
Dimitris Athanasiou 2018-12-07 18:01:37 +00:00 committed by GitHub
parent 2df4bd1f81
commit b8dba16376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 21 additions and 6 deletions

View File

@ -343,6 +343,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction<DeleteJo
} else {
SearchSourceBuilder source = new SearchSourceBuilder()
.size(1)
.trackTotalHits(true)
.query(QueryBuilders.boolQuery().filter(
QueryBuilders.boolQuery().mustNot(QueryBuilders.termQuery(Job.ID.getPreferredName(), jobId))));

View File

@ -257,6 +257,7 @@ public class TransportGetOverallBucketsAction extends HandledTransportAction<Get
.start(startTime)
.end(endTime)
.build();
searchSourceBuilder.trackTotalHits(true);
SearchRequest searchRequest = new SearchRequest(indices);
searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));

View File

@ -237,7 +237,8 @@ public class ChunkedDataExtractor implements DataExtractor {
private SearchRequestBuilder rangeSearchRequest() {
return new SearchRequestBuilder(client, SearchAction.INSTANCE)
.setIndices(context.indices)
.setSource(rangeSearchBuilder());
.setSource(rangeSearchBuilder())
.setTrackTotalHits(true);
}
private RollupSearchAction.RequestBuilder rollupRangeSearchRequest() {

View File

@ -98,6 +98,7 @@ public abstract class BatchedDocumentsIterator<T> {
.size(BATCH_SIZE)
.query(getQuery())
.fetchSource(shouldFetchSource())
.trackTotalHits(true)
.sort(SortBuilders.fieldSort(ElasticsearchMappings.ES_DOC)));
SearchResponse searchResponse = client.search(searchRequest).actionGet();

View File

@ -522,7 +522,7 @@ public class JobResultsProvider {
String indexName = AnomalyDetectorsIndex.jobResultsAliasedName(jobId);
SearchRequest searchRequest = new SearchRequest(indexName);
searchRequest.source(query.build());
searchRequest.source(query.build().trackTotalHits(true));
searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
@ -658,6 +658,7 @@ public class JobResultsProvider {
} else {
throw new IllegalStateException("Both categoryId and pageParams are not specified");
}
sourceBuilder.trackTotalHits(true);
searchRequest.source(sourceBuilder);
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
ActionListener.<SearchResponse>wrap(searchResponse -> {
@ -706,7 +707,7 @@ public class JobResultsProvider {
SearchSourceBuilder searchSourceBuilder = recordsQueryBuilder.build();
SearchRequest searchRequest = new SearchRequest(indexName);
searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(searchRequest.indicesOptions()));
searchRequest.source(recordsQueryBuilder.build());
searchRequest.source(recordsQueryBuilder.build().trackTotalHits(true));
LOGGER.trace("ES API CALL: search all of records from index {} with query {}", indexName, searchSourceBuilder);
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
@ -756,7 +757,7 @@ public class JobResultsProvider {
searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(searchRequest.indicesOptions()));
FieldSortBuilder sb = query.getSortField() == null ? SortBuilders.fieldSort(ElasticsearchMappings.ES_DOC)
: new FieldSortBuilder(query.getSortField()).order(query.isSortDescending() ? SortOrder.DESC : SortOrder.ASC);
searchRequest.source(new SearchSourceBuilder().query(qb).from(query.getFrom()).size(query.getSize()).sort(sb));
searchRequest.source(new SearchSourceBuilder().query(qb).from(query.getFrom()).size(query.getSize()).sort(sb).trackTotalHits(true));
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
ActionListener.<SearchResponse>wrap(response -> {
@ -877,6 +878,7 @@ public class JobResultsProvider {
sourceBuilder.query(finalQuery);
sourceBuilder.from(from);
sourceBuilder.size(size);
sourceBuilder.trackTotalHits(true);
searchRequest.source(sourceBuilder);
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
ActionListener.<SearchResponse>wrap(searchResponse -> {
@ -901,6 +903,7 @@ public class JobResultsProvider {
.setIndicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS))
.setQuery(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), ModelPlot.RESULT_TYPE_VALUE))
.setFrom(from).setSize(size)
.setTrackTotalHits(true)
.get();
}
@ -1088,7 +1091,8 @@ public class JobResultsProvider {
public void scheduledEvents(ScheduledEventsQueryBuilder query, ActionListener<QueryPage<ScheduledEvent>> handler) {
SearchRequestBuilder request = client.prepareSearch(MlMetaIndex.INDEX_NAME)
.setIndicesOptions(IndicesOptions.lenientExpandOpen())
.setSource(query.build());
.setSource(query.build())
.setTrackTotalHits(true);
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, request.request(),
ActionListener.<SearchResponse>wrap(
@ -1138,6 +1142,7 @@ public class JobResultsProvider {
sourceBuilder.aggregation(
AggregationBuilders.terms(ForecastStats.Fields.STATUSES).field(ForecastRequestStats.STATUS.getPreferredName()));
sourceBuilder.size(0);
sourceBuilder.trackTotalHits(true);
searchRequest.source(sourceBuilder);
@ -1211,6 +1216,7 @@ public class JobResultsProvider {
public void calendars(CalendarQueryBuilder queryBuilder, ActionListener<QueryPage<Calendar>> listener) {
SearchRequest searchRequest = client.prepareSearch(MlMetaIndex.INDEX_NAME)
.setIndicesOptions(IndicesOptions.lenientExpandOpen())
.setTrackTotalHits(true)
.setSource(queryBuilder.build()).request();
executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
@ -1222,7 +1228,7 @@ public class JobResultsProvider {
calendars.add(parseSearchHit(hit, Calendar.LENIENT_PARSER, listener::onFailure).build());
}
listener.onResponse(new QueryPage<Calendar>(calendars, response.getHits().getTotalHits().value,
listener.onResponse(new QueryPage<>(calendars, response.getHits().getTotalHits().value,
Calendar.RESULTS_FIELD));
},
listener::onFailure)

View File

@ -81,6 +81,7 @@ public class ExpiredForecastsRemover implements MlDataRemover {
.filter(QueryBuilders.termQuery(Result.RESULT_TYPE.getPreferredName(), ForecastRequestStats.RESULT_TYPE_VALUE))
.filter(QueryBuilders.existsQuery(ForecastRequestStats.EXPIRY_TIME.getPreferredName())));
source.size(MAX_FORECASTS);
source.trackTotalHits(true);
SearchRequest searchRequest = new SearchRequest(RESULTS_INDEX_PATTERN);
searchRequest.source(source);

View File

@ -139,6 +139,7 @@ public class ChunkedDataExtractorTests extends ESTestCase {
"\"format\":\"epoch_millis\",\"boost\":1.0}}}]"));
assertThat(searchRequest, containsString("\"aggregations\":{\"earliest_time\":{\"min\":{\"field\":\"time\"}}," +
"\"latest_time\":{\"max\":{\"field\":\"time\"}}}}"));
assertThat(searchRequest, not(containsString("\"track_total_hits\":false")));
assertThat(searchRequest, not(containsString("\"sort\"")));
}
@ -178,6 +179,7 @@ public class ChunkedDataExtractorTests extends ESTestCase {
"\"format\":\"epoch_millis\",\"boost\":1.0}}}]"));
assertThat(searchRequest, containsString("\"aggregations\":{\"earliest_time\":{\"min\":{\"field\":\"time\"}}," +
"\"latest_time\":{\"max\":{\"field\":\"time\"}}}}"));
assertThat(searchRequest, not(containsString("\"track_total_hits\":false")));
assertThat(searchRequest, not(containsString("\"sort\"")));
}

View File

@ -31,6 +31,7 @@ import java.util.List;
import java.util.NoSuchElementException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -138,6 +139,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
assertThat(searchRequest.types().length, equalTo(0));
assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
assertThat(searchRequest.source().trackTotalHits(), is(true));
}
private void assertSearchScrollRequests(int expectedCount) {