From f989cf873b0e2a5e60044ffa1af42b77b05e9012 Mon Sep 17 00:00:00 2001 From: Peter-Josef Meisch Date: Wed, 29 Jul 2020 09:49:55 +0200 Subject: [PATCH] DATAES-891 - Returning a Stream from a Query annotated repository method crashes. Original PR: #497 --- .../AbstractElasticsearchRepositoryQuery.java | 3 +- .../query/ElasticsearchPartQuery.java | 2 -- .../query/ElasticsearchStringQuery.java | 9 ++++++ .../CustomMethodRepositoryBaseTests.java | 29 +++++++++++++++++++ 4 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/query/AbstractElasticsearchRepositoryQuery.java b/src/main/java/org/springframework/data/elasticsearch/repository/query/AbstractElasticsearchRepositoryQuery.java index 88907be50..127d2cc94 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/query/AbstractElasticsearchRepositoryQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/query/AbstractElasticsearchRepositoryQuery.java @@ -28,11 +28,12 @@ import org.springframework.data.repository.query.RepositoryQuery; public abstract class AbstractElasticsearchRepositoryQuery implements RepositoryQuery { + protected static final int DEFAULT_STREAM_BATCH_SIZE = 500; protected ElasticsearchQueryMethod queryMethod; protected ElasticsearchOperations elasticsearchOperations; public AbstractElasticsearchRepositoryQuery(ElasticsearchQueryMethod queryMethod, - ElasticsearchOperations elasticsearchOperations) { + ElasticsearchOperations elasticsearchOperations) { this.queryMethod = queryMethod; this.elasticsearchOperations = elasticsearchOperations; } diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchPartQuery.java b/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchPartQuery.java index 151740a7a..647941885 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchPartQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchPartQuery.java @@ -43,8 +43,6 @@ import org.springframework.util.ClassUtils; */ public class ElasticsearchPartQuery extends AbstractElasticsearchRepositoryQuery { - private static final int DEFAULT_STREAM_BATCH_SIZE = 500; - private final PartTree tree; private final ElasticsearchConverter elasticsearchConverter; private final MappingContext mappingContext; diff --git a/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchStringQuery.java b/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchStringQuery.java index f6863067d..db1905dfd 100644 --- a/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchStringQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/repository/query/ElasticsearchStringQuery.java @@ -19,6 +19,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import org.springframework.core.convert.support.GenericConversionService; +import org.springframework.data.domain.PageRequest; import org.springframework.data.elasticsearch.core.ElasticsearchOperations; import org.springframework.data.elasticsearch.core.SearchHitSupport; import org.springframework.data.elasticsearch.core.SearchHits; @@ -26,6 +27,7 @@ import org.springframework.data.elasticsearch.core.convert.DateTimeConverters; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.query.StringQuery; import org.springframework.data.repository.query.ParametersParameterAccessor; +import org.springframework.data.util.StreamUtils; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.NumberUtils; @@ -88,6 +90,13 @@ public class ElasticsearchStringQuery extends AbstractElasticsearchRepositoryQue stringQuery.setPageable(accessor.getPageable()); SearchHits searchHits = elasticsearchOperations.search(stringQuery, clazz, index); result = SearchHitSupport.page(searchHits, stringQuery.getPageable()); + } else if (queryMethod.isStreamQuery()) { + if (accessor.getPageable().isUnpaged()) { + stringQuery.setPageable(PageRequest.of(0, DEFAULT_STREAM_BATCH_SIZE)); + } else { + stringQuery.setPageable(accessor.getPageable()); + } + result = StreamUtils.createStreamFromIterator(elasticsearchOperations.searchForStream(stringQuery, clazz, index)); } else if (queryMethod.isCollectionQuery()) { if (accessor.getPageable().isPaged()) { stringQuery.setPageable(accessor.getPageable()); diff --git a/src/test/java/org/springframework/data/elasticsearch/repositories/custommethod/CustomMethodRepositoryBaseTests.java b/src/test/java/org/springframework/data/elasticsearch/repositories/custommethod/CustomMethodRepositoryBaseTests.java index 9bffc6c0d..d2c2b3f71 100644 --- a/src/test/java/org/springframework/data/elasticsearch/repositories/custommethod/CustomMethodRepositoryBaseTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/repositories/custommethod/CustomMethodRepositoryBaseTests.java @@ -1544,6 +1544,28 @@ public abstract class CustomMethodRepositoryBaseTests { return entities; } + @Test // DATAES-891 + void shouldStreamEntitiesWithQueryAnnotatedMethod() { + List entities = createSampleEntities("abc", 20); + repository.saveAll(entities); + + Stream stream = streamingRepository.streamEntitiesByType("abc"); + + long count = stream.peek(sampleEntity -> assertThat(sampleEntity).isInstanceOf(SampleEntity.class)).count(); + assertThat(count).isEqualTo(20); + } + + @Test // DATAES-891 + void shouldStreamSearchHitsWithQueryAnnotatedMethod() { + List entities = createSampleEntities("abc", 20); + repository.saveAll(entities); + + Stream> stream = streamingRepository.streamSearchHitsByType("abc"); + + long count = stream.peek(sampleEntity -> assertThat(sampleEntity).isInstanceOf(SearchHit.class)).count(); + assertThat(count).isEqualTo(20); + } + @Data @NoArgsConstructor @AllArgsConstructor @@ -1687,5 +1709,12 @@ public abstract class CustomMethodRepositoryBaseTests { Stream findByType(String type); Stream findByType(String type, Pageable pageable); + + @Query("{\"bool\": {\"must\": [{\"term\": {\"type\": \"?0\"}}]}}") + Stream streamEntitiesByType(String type); + + @Query("{\"bool\": {\"must\": [{\"term\": {\"type\": \"?0\"}}]}}") + Stream> streamSearchHitsByType(String type); + } }