DATAES-891 - Returning a Stream from a Query annotated repository method crashes.

Original PR: #497
This commit is contained in:
Peter-Josef Meisch 2020-07-29 09:49:55 +02:00 committed by GitHub
parent fe458612e9
commit f989cf873b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 3 deletions

View File

@ -28,11 +28,12 @@ import org.springframework.data.repository.query.RepositoryQuery;
public abstract class AbstractElasticsearchRepositoryQuery implements RepositoryQuery { public abstract class AbstractElasticsearchRepositoryQuery implements RepositoryQuery {
protected static final int DEFAULT_STREAM_BATCH_SIZE = 500;
protected ElasticsearchQueryMethod queryMethod; protected ElasticsearchQueryMethod queryMethod;
protected ElasticsearchOperations elasticsearchOperations; protected ElasticsearchOperations elasticsearchOperations;
public AbstractElasticsearchRepositoryQuery(ElasticsearchQueryMethod queryMethod, public AbstractElasticsearchRepositoryQuery(ElasticsearchQueryMethod queryMethod,
ElasticsearchOperations elasticsearchOperations) { ElasticsearchOperations elasticsearchOperations) {
this.queryMethod = queryMethod; this.queryMethod = queryMethod;
this.elasticsearchOperations = elasticsearchOperations; this.elasticsearchOperations = elasticsearchOperations;
} }

View File

@ -43,8 +43,6 @@ import org.springframework.util.ClassUtils;
*/ */
public class ElasticsearchPartQuery extends AbstractElasticsearchRepositoryQuery { public class ElasticsearchPartQuery extends AbstractElasticsearchRepositoryQuery {
private static final int DEFAULT_STREAM_BATCH_SIZE = 500;
private final PartTree tree; private final PartTree tree;
private final ElasticsearchConverter elasticsearchConverter; private final ElasticsearchConverter elasticsearchConverter;
private final MappingContext<?, ElasticsearchPersistentProperty> mappingContext; private final MappingContext<?, ElasticsearchPersistentProperty> mappingContext;

View File

@ -19,6 +19,7 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.springframework.core.convert.support.GenericConversionService; import org.springframework.core.convert.support.GenericConversionService;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.elasticsearch.core.ElasticsearchOperations; import org.springframework.data.elasticsearch.core.ElasticsearchOperations;
import org.springframework.data.elasticsearch.core.SearchHitSupport; import org.springframework.data.elasticsearch.core.SearchHitSupport;
import org.springframework.data.elasticsearch.core.SearchHits; import org.springframework.data.elasticsearch.core.SearchHits;
@ -26,6 +27,7 @@ import org.springframework.data.elasticsearch.core.convert.DateTimeConverters;
import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
import org.springframework.data.elasticsearch.core.query.StringQuery; import org.springframework.data.elasticsearch.core.query.StringQuery;
import org.springframework.data.repository.query.ParametersParameterAccessor; import org.springframework.data.repository.query.ParametersParameterAccessor;
import org.springframework.data.util.StreamUtils;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.NumberUtils; import org.springframework.util.NumberUtils;
@ -88,6 +90,13 @@ public class ElasticsearchStringQuery extends AbstractElasticsearchRepositoryQue
stringQuery.setPageable(accessor.getPageable()); stringQuery.setPageable(accessor.getPageable());
SearchHits<?> searchHits = elasticsearchOperations.search(stringQuery, clazz, index); SearchHits<?> searchHits = elasticsearchOperations.search(stringQuery, clazz, index);
result = SearchHitSupport.page(searchHits, stringQuery.getPageable()); result = SearchHitSupport.page(searchHits, stringQuery.getPageable());
} else if (queryMethod.isStreamQuery()) {
if (accessor.getPageable().isUnpaged()) {
stringQuery.setPageable(PageRequest.of(0, DEFAULT_STREAM_BATCH_SIZE));
} else {
stringQuery.setPageable(accessor.getPageable());
}
result = StreamUtils.createStreamFromIterator(elasticsearchOperations.searchForStream(stringQuery, clazz, index));
} else if (queryMethod.isCollectionQuery()) { } else if (queryMethod.isCollectionQuery()) {
if (accessor.getPageable().isPaged()) { if (accessor.getPageable().isPaged()) {
stringQuery.setPageable(accessor.getPageable()); stringQuery.setPageable(accessor.getPageable());

View File

@ -1544,6 +1544,28 @@ public abstract class CustomMethodRepositoryBaseTests {
return entities; return entities;
} }
@Test // DATAES-891
void shouldStreamEntitiesWithQueryAnnotatedMethod() {
List<SampleEntity> entities = createSampleEntities("abc", 20);
repository.saveAll(entities);
Stream<SampleEntity> stream = streamingRepository.streamEntitiesByType("abc");
long count = stream.peek(sampleEntity -> assertThat(sampleEntity).isInstanceOf(SampleEntity.class)).count();
assertThat(count).isEqualTo(20);
}
@Test // DATAES-891
void shouldStreamSearchHitsWithQueryAnnotatedMethod() {
List<SampleEntity> entities = createSampleEntities("abc", 20);
repository.saveAll(entities);
Stream<SearchHit<SampleEntity>> stream = streamingRepository.streamSearchHitsByType("abc");
long count = stream.peek(sampleEntity -> assertThat(sampleEntity).isInstanceOf(SearchHit.class)).count();
assertThat(count).isEqualTo(20);
}
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
@ -1687,5 +1709,12 @@ public abstract class CustomMethodRepositoryBaseTests {
Stream<SampleEntity> findByType(String type); Stream<SampleEntity> findByType(String type);
Stream<SampleEntity> findByType(String type, Pageable pageable); Stream<SampleEntity> findByType(String type, Pageable pageable);
@Query("{\"bool\": {\"must\": [{\"term\": {\"type\": \"?0\"}}]}}")
Stream<SampleEntity> streamEntitiesByType(String type);
@Query("{\"bool\": {\"must\": [{\"term\": {\"type\": \"?0\"}}]}}")
Stream<SearchHit<SampleEntity>> streamSearchHitsByType(String type);
} }
} }