diff --git a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java index 2e5b29519..bef21bfc6 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java @@ -238,7 +238,7 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper @Override public void delete(Query query, Class clazz) { - delete(query, getIndexCoordinatesFor(clazz)); + delete(query, clazz, getIndexCoordinatesFor(clazz)); } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index cd3358761..450f4ef09 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -256,7 +256,7 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate { Assert.notNull(query, "query must not be null"); Assert.notNull(index, "index must not be null"); - final boolean trackTotalHits = query.getTrackTotalHits(); + final Boolean trackTotalHits = query.getTrackTotalHits(); query.setTrackTotalHits(true); SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index); query.setTrackTotalHits(trackTotalHits); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index e385fb341..c6442760d 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -276,7 +276,7 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate { Assert.notNull(query, "query must not be null"); Assert.notNull(index, "index must not be null"); - final boolean trackTotalHits = query.getTrackTotalHits(); + final Boolean trackTotalHits = query.getTrackTotalHits(); query.setTrackTotalHits(true); SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index); query.setTrackTotalHits(trackTotalHits); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java index 9d90d5272..5ba852168 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java @@ -1152,8 +1152,10 @@ class RequestFactory { } - if (query.getTrackTotalHits()) { + if (query.getTrackTotalHits() != null) { sourceBuilder.trackTotalHits(query.getTrackTotalHits()); + } else if (query.getTrackTotalHitsUpTo() != null) { + sourceBuilder.trackTotalHitsUpTo(query.getTrackTotalHitsUpTo()); } if (StringUtils.hasLength(query.getRoute())) { @@ -1225,8 +1227,10 @@ class RequestFactory { prepareNativeSearch(searchRequestBuilder, (NativeSearchQuery) query); } - if (query.getTrackTotalHits()) { + if (query.getTrackTotalHits() != null) { searchRequestBuilder.setTrackTotalHits(query.getTrackTotalHits()); + } else if (query.getTrackTotalHitsUpTo() != null) { + searchRequestBuilder.setTrackTotalHitsUpTo(query.getTrackTotalHitsUpTo()); } if (StringUtils.hasLength(query.getRoute())) { diff --git a/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java b/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java index 14f069de0..ce532015f 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java @@ -26,5 +26,9 @@ package org.springframework.data.elasticsearch.core; */ public enum TotalHitsRelation { EQUAL_TO, // - GREATER_THAN_OR_EQUAL_TO + GREATER_THAN_OR_EQUAL_TO, // + /** + * @since 4.1 + */ + OFF } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java b/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java index cadbbef45..5b5854a42 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/document/SearchDocumentResponse.java @@ -36,9 +36,9 @@ import org.springframework.util.Assert; */ public class SearchDocumentResponse { - private long totalHits; - private String totalHitsRelation; - private float maxScore; + private final long totalHits; + private final String totalHitsRelation; + private final float maxScore; private final String scrollId; private final List searchDocuments; private final Aggregations aggregations; @@ -108,8 +108,17 @@ public class SearchDocumentResponse { public static SearchDocumentResponse from(SearchHits searchHits, @Nullable String scrollId, @Nullable Aggregations aggregations) { TotalHits responseTotalHits = searchHits.getTotalHits(); - long totalHits = responseTotalHits.value; - String totalHitsRelation = responseTotalHits.relation.name(); + + long totalHits; + String totalHitsRelation; + + if (responseTotalHits != null) { + totalHits = responseTotalHits.value; + totalHitsRelation = responseTotalHits.relation.name(); + } else { + totalHits = searchHits.getHits().length; + totalHitsRelation = "OFF"; + } float maxScore = searchHits.getMaxScore(); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java index 3bb80a398..e0e284831 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java @@ -56,7 +56,8 @@ abstract class AbstractQuery implements Query { @Nullable protected String preference; @Nullable protected Integer maxResults; @Nullable protected HighlightQuery highlightQuery; - private boolean trackTotalHits = false; + @Nullable private Boolean trackTotalHits; + @Nullable private Integer trackTotalHitsUpTo; @Nullable private Duration scrollTime; @Override @@ -220,15 +221,27 @@ abstract class AbstractQuery implements Query { } @Override - public void setTrackTotalHits(boolean trackTotalHits) { + public void setTrackTotalHits(@Nullable Boolean trackTotalHits) { this.trackTotalHits = trackTotalHits; } @Override - public boolean getTrackTotalHits() { + @Nullable + public Boolean getTrackTotalHits() { return trackTotalHits; } + @Override + public void setTrackTotalHitsUpTo(@Nullable Integer trackTotalHitsUpTo) { + this.trackTotalHitsUpTo = trackTotalHitsUpTo; + } + + @Override + @Nullable + public Integer getTrackTotalHitsUpTo() { + return trackTotalHitsUpTo; + } + @Nullable @Override public Duration getScrollTime() { diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java index 47f5d216c..e5c4b907a 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java @@ -219,7 +219,7 @@ public interface Query { * @param trackTotalHits the value to set. * @since 4.0 */ - void setTrackTotalHits(boolean trackTotalHits); + void setTrackTotalHits(@Nullable Boolean trackTotalHits); /** * Sets the flag whether to set the Track_total_hits parameter on queries {@see entities = IntStream.rangeClosed(1, 15_000) + .mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList()); + + operations.save(entities); + indexOperations.refresh(); + + queryAll.setTrackTotalHits(null); + SearchHits searchHits = operations.search(queryAll, SampleEntity.class); + + SoftAssertions softly = new SoftAssertions(); + softly.assertThat(searchHits.getTotalHits()).isEqualTo((long) RequestFactory.INDEX_MAX_RESULT_WINDOW); + softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.GREATER_THAN_OR_EQUAL_TO); + softly.assertAll(); + } + + @Test // DATAES-907 + @DisplayName("should track total hits") + void shouldTrackTotalHits() { + + NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build(); + operations.delete(queryAll, SampleEntity.class); + + List entities = IntStream.rangeClosed(1, 15_000) + .mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList()); + + operations.save(entities); + indexOperations.refresh(); + + queryAll.setTrackTotalHits(true); + queryAll.setTrackTotalHitsUpTo(12_345); + SearchHits searchHits = operations.search(queryAll, SampleEntity.class); + + SoftAssertions softly = new SoftAssertions(); + softly.assertThat(searchHits.getTotalHits()).isEqualTo(15_000L); + softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.EQUAL_TO); + softly.assertAll(); + } + + @Test // DATAES-907 + @DisplayName("should track total hits to specific value") + void shouldTrackTotalHitsToSpecificValue() { + + NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build(); + operations.delete(queryAll, SampleEntity.class); + + List entities = IntStream.rangeClosed(1, 15_000) + .mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList()); + + operations.save(entities); + indexOperations.refresh(); + + queryAll.setTrackTotalHits(null); + queryAll.setTrackTotalHitsUpTo(12_345); + SearchHits searchHits = operations.search(queryAll, SampleEntity.class); + + SoftAssertions softly = new SoftAssertions(); + softly.assertThat(searchHits.getTotalHits()).isEqualTo(12_345L); + softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.GREATER_THAN_OR_EQUAL_TO); + softly.assertAll(); + } + + @Test + @DisplayName("should track total hits is off") + void shouldTrackTotalHitsIsOff() { + + NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build(); + operations.delete(queryAll, SampleEntity.class); + + List entities = IntStream.rangeClosed(1, 15_000) + .mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList()); + + operations.save(entities); + indexOperations.refresh(); + + queryAll.setTrackTotalHits(false); + queryAll.setTrackTotalHitsUpTo(12_345); + SearchHits searchHits = operations.search(queryAll, SampleEntity.class); + + SoftAssertions softly = new SoftAssertions(); + softly.assertThat(searchHits.getTotalHits()).isEqualTo(10_000L); + softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.OFF); + softly.assertAll(); + } + @Data @NoArgsConstructor @AllArgsConstructor diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateTests.java index d2379cafe..5a214ea97 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.engine.DocumentMissingException; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.annotation.Id; @@ -56,6 +57,7 @@ import org.springframework.test.context.ContextConfiguration; */ @SpringIntegrationTest @ContextConfiguration(classes = { ElasticsearchTemplateConfiguration.class }) +@DisplayName("ElasticsearchTransportTemplate") public class ElasticsearchTransportTemplateTests extends ElasticsearchTemplateTests { @Autowired private Client client;