DATAES-907 - Track Total Hits not working when set to false.

Original PR: #515
This commit is contained in:
Peter-Josef Meisch 2020-08-28 23:06:42 +02:00 committed by GitHub
parent 4344a65dc2
commit ef1cbc35f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 162 additions and 16 deletions

View File

@ -238,7 +238,7 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper
@Override @Override
public void delete(Query query, Class<?> clazz) { public void delete(Query query, Class<?> clazz) {
delete(query, getIndexCoordinatesFor(clazz)); delete(query, clazz, getIndexCoordinatesFor(clazz));
} }
@Override @Override

View File

@ -256,7 +256,7 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate {
Assert.notNull(query, "query must not be null"); Assert.notNull(query, "query must not be null");
Assert.notNull(index, "index must not be null"); Assert.notNull(index, "index must not be null");
final boolean trackTotalHits = query.getTrackTotalHits(); final Boolean trackTotalHits = query.getTrackTotalHits();
query.setTrackTotalHits(true); query.setTrackTotalHits(true);
SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index); SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index);
query.setTrackTotalHits(trackTotalHits); query.setTrackTotalHits(trackTotalHits);

View File

@ -276,7 +276,7 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate {
Assert.notNull(query, "query must not be null"); Assert.notNull(query, "query must not be null");
Assert.notNull(index, "index must not be null"); Assert.notNull(index, "index must not be null");
final boolean trackTotalHits = query.getTrackTotalHits(); final Boolean trackTotalHits = query.getTrackTotalHits();
query.setTrackTotalHits(true); query.setTrackTotalHits(true);
SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index); SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index);
query.setTrackTotalHits(trackTotalHits); query.setTrackTotalHits(trackTotalHits);

View File

@ -1152,8 +1152,10 @@ class RequestFactory {
} }
if (query.getTrackTotalHits()) { if (query.getTrackTotalHits() != null) {
sourceBuilder.trackTotalHits(query.getTrackTotalHits()); sourceBuilder.trackTotalHits(query.getTrackTotalHits());
} else if (query.getTrackTotalHitsUpTo() != null) {
sourceBuilder.trackTotalHitsUpTo(query.getTrackTotalHitsUpTo());
} }
if (StringUtils.hasLength(query.getRoute())) { if (StringUtils.hasLength(query.getRoute())) {
@ -1225,8 +1227,10 @@ class RequestFactory {
prepareNativeSearch(searchRequestBuilder, (NativeSearchQuery) query); prepareNativeSearch(searchRequestBuilder, (NativeSearchQuery) query);
} }
if (query.getTrackTotalHits()) { if (query.getTrackTotalHits() != null) {
searchRequestBuilder.setTrackTotalHits(query.getTrackTotalHits()); searchRequestBuilder.setTrackTotalHits(query.getTrackTotalHits());
} else if (query.getTrackTotalHitsUpTo() != null) {
searchRequestBuilder.setTrackTotalHitsUpTo(query.getTrackTotalHitsUpTo());
} }
if (StringUtils.hasLength(query.getRoute())) { if (StringUtils.hasLength(query.getRoute())) {

View File

@ -26,5 +26,9 @@ package org.springframework.data.elasticsearch.core;
*/ */
public enum TotalHitsRelation { public enum TotalHitsRelation {
EQUAL_TO, // EQUAL_TO, //
GREATER_THAN_OR_EQUAL_TO GREATER_THAN_OR_EQUAL_TO, //
/**
* @since 4.1
*/
OFF
} }

View File

@ -36,9 +36,9 @@ import org.springframework.util.Assert;
*/ */
public class SearchDocumentResponse { public class SearchDocumentResponse {
private long totalHits; private final long totalHits;
private String totalHitsRelation; private final String totalHitsRelation;
private float maxScore; private final float maxScore;
private final String scrollId; private final String scrollId;
private final List<SearchDocument> searchDocuments; private final List<SearchDocument> searchDocuments;
private final Aggregations aggregations; private final Aggregations aggregations;
@ -108,8 +108,17 @@ public class SearchDocumentResponse {
public static SearchDocumentResponse from(SearchHits searchHits, @Nullable String scrollId, public static SearchDocumentResponse from(SearchHits searchHits, @Nullable String scrollId,
@Nullable Aggregations aggregations) { @Nullable Aggregations aggregations) {
TotalHits responseTotalHits = searchHits.getTotalHits(); TotalHits responseTotalHits = searchHits.getTotalHits();
long totalHits = responseTotalHits.value;
String totalHitsRelation = responseTotalHits.relation.name(); long totalHits;
String totalHitsRelation;
if (responseTotalHits != null) {
totalHits = responseTotalHits.value;
totalHitsRelation = responseTotalHits.relation.name();
} else {
totalHits = searchHits.getHits().length;
totalHitsRelation = "OFF";
}
float maxScore = searchHits.getMaxScore(); float maxScore = searchHits.getMaxScore();

View File

@ -56,7 +56,8 @@ abstract class AbstractQuery implements Query {
@Nullable protected String preference; @Nullable protected String preference;
@Nullable protected Integer maxResults; @Nullable protected Integer maxResults;
@Nullable protected HighlightQuery highlightQuery; @Nullable protected HighlightQuery highlightQuery;
private boolean trackTotalHits = false; @Nullable private Boolean trackTotalHits;
@Nullable private Integer trackTotalHitsUpTo;
@Nullable private Duration scrollTime; @Nullable private Duration scrollTime;
@Override @Override
@ -220,15 +221,27 @@ abstract class AbstractQuery implements Query {
} }
@Override @Override
public void setTrackTotalHits(boolean trackTotalHits) { public void setTrackTotalHits(@Nullable Boolean trackTotalHits) {
this.trackTotalHits = trackTotalHits; this.trackTotalHits = trackTotalHits;
} }
@Override @Override
public boolean getTrackTotalHits() { @Nullable
public Boolean getTrackTotalHits() {
return trackTotalHits; return trackTotalHits;
} }
@Override
public void setTrackTotalHitsUpTo(@Nullable Integer trackTotalHitsUpTo) {
this.trackTotalHitsUpTo = trackTotalHitsUpTo;
}
@Override
@Nullable
public Integer getTrackTotalHitsUpTo() {
return trackTotalHitsUpTo;
}
@Nullable @Nullable
@Override @Override
public Duration getScrollTime() { public Duration getScrollTime() {

View File

@ -219,7 +219,7 @@ public interface Query {
* @param trackTotalHits the value to set. * @param trackTotalHits the value to set.
* @since 4.0 * @since 4.0
*/ */
void setTrackTotalHits(boolean trackTotalHits); void setTrackTotalHits(@Nullable Boolean trackTotalHits);
/** /**
* Sets the flag whether to set the Track_total_hits parameter on queries {@see <a href= * Sets the flag whether to set the Track_total_hits parameter on queries {@see <a href=
@ -229,7 +229,25 @@ public interface Query {
* @return the set value. * @return the set value.
* @since 4.0 * @since 4.0
*/ */
boolean getTrackTotalHits(); @Nullable
Boolean getTrackTotalHits();
/**
* Sets the maximum value up to which total hits are tracked. Only relevant if #getTrackTotalHits is {@literal null}
*
* @param trackTotalHitsUpTo max limit for trackTotalHits
* @since 4.1
*/
void setTrackTotalHitsUpTo(@Nullable Integer trackTotalHitsUpTo);
/**
* Gets the maximum value up to which total hits are tracked. Only relevant if #getTrackTotalHits is {@literal null}
*
* @return max limit for trackTotalHits
* @since 4.1
*/
@Nullable
Integer getTrackTotalHitsUpTo();
/** /**
* For queries that are used in delete request, these are internally handled by Elasticsearch as scroll/bulk delete * For queries that are used in delete request, these are internally handled by Elasticsearch as scroll/bulk delete

View File

@ -32,6 +32,7 @@ import org.elasticsearch.action.support.ActiveShardCount;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.elasticsearch.UncategorizedElasticsearchException; import org.springframework.data.elasticsearch.UncategorizedElasticsearchException;
@ -59,6 +60,7 @@ import org.springframework.test.context.ContextConfiguration;
*/ */
@SpringIntegrationTest @SpringIntegrationTest
@ContextConfiguration(classes = { ElasticsearchRestTemplateConfiguration.class }) @ContextConfiguration(classes = { ElasticsearchRestTemplateConfiguration.class })
@DisplayName("ElasticsearchRestTemplate")
public class ElasticsearchRestTemplateTests extends ElasticsearchTemplateTests { public class ElasticsearchRestTemplateTests extends ElasticsearchTemplateTests {
@Test @Test

View File

@ -44,7 +44,9 @@ import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.assertj.core.api.SoftAssertions;
import org.assertj.core.util.Lists; import org.assertj.core.util.Lists;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.IndicesOptions;
@ -61,6 +63,7 @@ import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.dao.OptimisticLockingFailureException;
@ -3511,6 +3514,97 @@ public abstract class ElasticsearchTemplateTests {
assertThatSeqNoPrimaryTermIsFilled(entity2); assertThatSeqNoPrimaryTermIsFilled(entity2);
} }
@Test // DATAES-907
@DisplayName("should track_total_hits with default value")
void shouldTrackTotalHitsWithDefaultValue() {
NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build();
operations.delete(queryAll, SampleEntity.class);
List<SampleEntity> entities = IntStream.rangeClosed(1, 15_000)
.mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList());
operations.save(entities);
indexOperations.refresh();
queryAll.setTrackTotalHits(null);
SearchHits<SampleEntity> searchHits = operations.search(queryAll, SampleEntity.class);
SoftAssertions softly = new SoftAssertions();
softly.assertThat(searchHits.getTotalHits()).isEqualTo((long) RequestFactory.INDEX_MAX_RESULT_WINDOW);
softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.GREATER_THAN_OR_EQUAL_TO);
softly.assertAll();
}
@Test // DATAES-907
@DisplayName("should track total hits")
void shouldTrackTotalHits() {
NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build();
operations.delete(queryAll, SampleEntity.class);
List<SampleEntity> entities = IntStream.rangeClosed(1, 15_000)
.mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList());
operations.save(entities);
indexOperations.refresh();
queryAll.setTrackTotalHits(true);
queryAll.setTrackTotalHitsUpTo(12_345);
SearchHits<SampleEntity> searchHits = operations.search(queryAll, SampleEntity.class);
SoftAssertions softly = new SoftAssertions();
softly.assertThat(searchHits.getTotalHits()).isEqualTo(15_000L);
softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.EQUAL_TO);
softly.assertAll();
}
@Test // DATAES-907
@DisplayName("should track total hits to specific value")
void shouldTrackTotalHitsToSpecificValue() {
NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build();
operations.delete(queryAll, SampleEntity.class);
List<SampleEntity> entities = IntStream.rangeClosed(1, 15_000)
.mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList());
operations.save(entities);
indexOperations.refresh();
queryAll.setTrackTotalHits(null);
queryAll.setTrackTotalHitsUpTo(12_345);
SearchHits<SampleEntity> searchHits = operations.search(queryAll, SampleEntity.class);
SoftAssertions softly = new SoftAssertions();
softly.assertThat(searchHits.getTotalHits()).isEqualTo(12_345L);
softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.GREATER_THAN_OR_EQUAL_TO);
softly.assertAll();
}
@Test
@DisplayName("should track total hits is off")
void shouldTrackTotalHitsIsOff() {
NativeSearchQuery queryAll = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).build();
operations.delete(queryAll, SampleEntity.class);
List<SampleEntity> entities = IntStream.rangeClosed(1, 15_000)
.mapToObj(i -> SampleEntity.builder().id("" + i).build()).collect(Collectors.toList());
operations.save(entities);
indexOperations.refresh();
queryAll.setTrackTotalHits(false);
queryAll.setTrackTotalHitsUpTo(12_345);
SearchHits<SampleEntity> searchHits = operations.search(queryAll, SampleEntity.class);
SoftAssertions softly = new SoftAssertions();
softly.assertThat(searchHits.getTotalHits()).isEqualTo(10_000L);
softly.assertThat(searchHits.getTotalHitsRelation()).isEqualTo(TotalHitsRelation.OFF);
softly.assertAll();
}
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor

View File

@ -34,6 +34,7 @@ import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.index.engine.DocumentMissingException;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
@ -56,6 +57,7 @@ import org.springframework.test.context.ContextConfiguration;
*/ */
@SpringIntegrationTest @SpringIntegrationTest
@ContextConfiguration(classes = { ElasticsearchTemplateConfiguration.class }) @ContextConfiguration(classes = { ElasticsearchTemplateConfiguration.class })
@DisplayName("ElasticsearchTransportTemplate")
public class ElasticsearchTransportTemplateTests extends ElasticsearchTemplateTests { public class ElasticsearchTransportTemplateTests extends ElasticsearchTemplateTests {
@Autowired private Client client; @Autowired private Client client;