From d55cb00d4576e1c7b95f121438dfeb1371ea7232 Mon Sep 17 00:00:00 2001 From: Peter-Josef Meisch Date: Fri, 10 Jan 2020 15:52:03 +0100 Subject: [PATCH] DATAES-727 - Use track_total_hits parameter for count queries. Original PR: #379 --- .../DefaultReactiveElasticsearchClient.java | 8 ++-- .../reactive/ReactiveElasticsearchClient.java | 23 ++++++----- .../core/ElasticsearchRestTemplate.java | 5 +++ .../core/ElasticsearchTemplate.java | 4 ++ .../core/ReactiveElasticsearchTemplate.java | 38 +++++-------------- 5 files changed, 34 insertions(+), 44 deletions(-) diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java index db3b3216e..8c1466cc4 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java @@ -79,7 +79,6 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Request; import org.elasticsearch.client.core.CountRequest; -import org.elasticsearch.client.core.CountResponse; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -336,9 +335,10 @@ public class DefaultReactiveElasticsearchClient implements ReactiveElasticsearch * @see org.springframework.data.elasticsearch.client.reactive.ReactiveElasticsearchClient#count(org.springframework.http.HttpHeaders, org.elasticsearch.action.search.SearchRequest) */ @Override - public Mono count(HttpHeaders headers, CountRequest countRequest) { - return sendRequest(countRequest, RequestCreator.count(), CountResponse.class, headers) // - .map(CountResponse::getCount) // + public Mono count(HttpHeaders headers, SearchRequest searchRequest) { + return sendRequest(searchRequest, RequestCreator.search(), SearchResponse.class, headers) // + .map(SearchResponse::getHits) // + .map(searchHits -> searchHits.getTotalHits().value) // .next(); } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClient.java b/src/main/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClient.java index 6480074a3..914a43395 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClient.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/reactive/ReactiveElasticsearchClient.java @@ -43,7 +43,6 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateResponse; -import org.elasticsearch.client.core.CountRequest; import org.elasticsearch.index.get.GetResult; import org.elasticsearch.index.reindex.BulkByScrollResponse; import org.elasticsearch.index.reindex.DeleteByQueryRequest; @@ -342,37 +341,37 @@ public interface ReactiveElasticsearchClient { * @return the {@link Mono} emitting the count result. * @since 4.0 */ - default Mono count(Consumer consumer) { + default Mono count(Consumer consumer) { - CountRequest countRequest = new CountRequest(); - consumer.accept(countRequest); - return count(countRequest); + SearchRequest searchRequest = new SearchRequest(); + consumer.accept(searchRequest); + return count(searchRequest); } /** - * Execute a {@link SearchRequest} against the {@literal count} API. + * Execute a {@link SearchRequest} against the {@literal search} API. * - * @param countRequest must not be {@literal null}. + * @param searchRequest must not be {@literal null}. * @see Count API on * elastic.co * @return the {@link Mono} emitting the count result. * @since 4.0 */ - default Mono count(CountRequest countRequest) { - return count(HttpHeaders.EMPTY, countRequest); + default Mono count(SearchRequest searchRequest) { + return count(HttpHeaders.EMPTY, searchRequest); } /** - * Execute a {@link SearchRequest} against the {@literal count} API. + * Execute a {@link SearchRequest} against the {@literal search} API. * * @param headers Use {@link HttpHeaders} to provide eg. authentication data. Must not be {@literal null}. - * @param countRequest must not be {@literal null}. + * @param searchRequest must not be {@literal null}. * @see Count API on * elastic.co * @return the {@link Mono} emitting the count result. * @since 4.0 */ - Mono count(HttpHeaders headers, CountRequest countRequest); + Mono count(HttpHeaders headers, SearchRequest searchRequest); /** * Execute a {@link SearchRequest} against the {@literal search} API. diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index dd4ade833..314680272 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -212,8 +212,13 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate { @Override public long count(Query query, Class clazz, IndexCoordinates index) { + Assert.notNull(query, "query must not be null"); Assert.notNull(index, "index must not be null"); + + final boolean trackTotalHits = query.getTrackTotalHits(); + query.setTrackTotalHits(true); SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index); + query.setTrackTotalHits(trackTotalHits); searchRequest.source().size(0); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index f6793032e..24ce4df8e 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -185,9 +185,13 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate { @Override public long count(Query query, @Nullable Class clazz, IndexCoordinates index) { + Assert.notNull(query, "query must not be null"); Assert.notNull(index, "index must not be null"); + final boolean trackTotalHits = query.getTrackTotalHits(); + query.setTrackTotalHits(true); SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index); + query.setTrackTotalHits(trackTotalHits); searchRequestBuilder.setSize(0); return SearchHitsUtil.getTotalCount(getSearchResponse(searchRequestBuilder).getHits()); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java index 8d36358c3..ae2357c78 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java @@ -435,10 +435,7 @@ public class ReactiveElasticsearchTemplate implements ReactiveElasticsearchOpera return Flux.defer(() -> { SearchRequest request = requestFactory.searchRequest(query, clazz, index); - - if (indicesOptions != null) { - request.indicesOptions(indicesOptions); - } + request = prepareSearchRequest(request); if (query.getPageable().isPaged() || query.isLimiting()) { return doFind(request); @@ -455,17 +452,18 @@ public class ReactiveElasticsearchTemplate implements ReactiveElasticsearchOpera @Override public Mono count(Query query, Class entityType, IndexCoordinates index) { - return doCount(query, getPersistentEntityFor(entityType), index); + return doCount(query, entityType, index); } - private Mono doCount(Query query, ElasticsearchPersistentEntity entity, IndexCoordinates index) { + private Mono doCount(Query query, Class entityType, IndexCoordinates index) { return Mono.defer(() -> { - CountRequest countRequest = buildCountRequest(query, entity, index); - CountRequest request = prepareCountRequest(countRequest); + SearchRequest request = requestFactory.searchRequest(query, entityType, index); + request = prepareSearchRequest(request); + request.source().size(0); + request.source().trackTotalHits(true); return doCount(request); }); - } private CountRequest buildCountRequest(Query query, ElasticsearchPersistentEntity entity, IndexCoordinates index) { @@ -524,17 +522,17 @@ public class ReactiveElasticsearchTemplate implements ReactiveElasticsearchOpera /** * Customization hook on the actual execution result {@link Publisher}.
* - * @param request the already prepared {@link CountRequest} ready to be executed. + * @param request the already prepared {@link SearchRequest} ready to be executed. * @return a {@link Mono} emitting the result of the operation. */ - protected Mono doCount(CountRequest request) { + protected Mono doCount(SearchRequest request) { if (QUERY_LOGGER.isDebugEnabled()) { QUERY_LOGGER.debug("Executing doCount: {}", request); } return Mono.from(execute(client -> client.count(request))) // - .onErrorResume(NoSuchIndexException.class, it -> Mono.empty()); + .onErrorResume(NoSuchIndexException.class, it -> Mono.just(0L)); } /** @@ -608,22 +606,6 @@ public class ReactiveElasticsearchTemplate implements ReactiveElasticsearchOpera return mappedSort; } - /** - * Customization hook to modify a generated {@link SearchRequest} prior to its execution. Eg. by setting the - * {@link SearchRequest#indicesOptions(IndicesOptions) indices options} if applicable. - * - * @param request the generated {@link CountRequest}. - * @return never {@literal null}. - */ - protected CountRequest prepareCountRequest(CountRequest request) { - - if (indicesOptions == null) { - return request; - } - - return request.indicesOptions(indicesOptions); - } - /** * Customization hook to modify a generated {@link SearchRequest} prior to its execution. Eg. by setting the * {@link SearchRequest#indicesOptions(IndicesOptions) indices options} if applicable.