DATAES-831 - SearchOperations.searchForStream does not use requested maxResults.

Original PR: #459

(cherry picked from commit 506f79a45aa93ad5787b25d807de5e5970bf0ea3)
This commit is contained in:
Peter-Josef Meisch 2020-05-17 10:49:50 +02:00
parent 1cee4057d9
commit e7110c14ab
6 changed files with 120 additions and 46 deletions

View File

@ -258,7 +258,11 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper
long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis(); long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis();
// noinspection ConstantConditions
int maxCount = query.isLimiting() ? query.getMaxResults() : 0;
return StreamQueries.streamResults( // return StreamQueries.streamResults( //
maxCount, //
searchScrollStart(scrollTimeInMillis, query, clazz, index), // searchScrollStart(scrollTimeInMillis, query, clazz, index), //
scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz, index), // scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz, index), //
this::searchScrollClear); this::searchScrollClear);

View File

@ -40,6 +40,8 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.SuggestBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter; import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter;
import org.springframework.data.elasticsearch.core.document.DocumentAdapters; import org.springframework.data.elasticsearch.core.document.DocumentAdapters;
import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse;
@ -88,6 +90,8 @@ import org.springframework.util.Assert;
*/ */
public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate { public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate {
private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchRestTemplate.class);
private RestHighLevelClient client; private RestHighLevelClient client;
private ElasticsearchExceptionTranslator exceptionTranslator; private ElasticsearchExceptionTranslator exceptionTranslator;
@ -300,9 +304,13 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate {
@Override @Override
public void searchScrollClear(List<String> scrollIds) { public void searchScrollClear(List<String> scrollIds) {
ClearScrollRequest request = new ClearScrollRequest(); try {
request.scrollIds(scrollIds); ClearScrollRequest request = new ClearScrollRequest();
execute(client -> client.clearScroll(request, RequestOptions.DEFAULT)); request.scrollIds(scrollIds);
execute(client -> client.clearScroll(request, RequestOptions.DEFAULT));
} catch (Exception e) {
LOGGER.warn("Could not clear scroll: {}", e.getMessage());
}
} }
@Override @Override

View File

@ -86,6 +86,7 @@ import org.springframework.util.Assert;
public class ElasticsearchTemplate extends AbstractElasticsearchTemplate { public class ElasticsearchTemplate extends AbstractElasticsearchTemplate {
private static final Logger QUERY_LOGGER = LoggerFactory private static final Logger QUERY_LOGGER = LoggerFactory
.getLogger("org.springframework.data.elasticsearch.core.QUERY"); .getLogger("org.springframework.data.elasticsearch.core.QUERY");
private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchTemplate.class);
private Client client; private Client client;
@Nullable private String searchTimeout; @Nullable private String searchTimeout;
@ -322,7 +323,11 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate {
@Override @Override
public void searchScrollClear(List<String> scrollIds) { public void searchScrollClear(List<String> scrollIds) {
client.prepareClearScroll().setScrollIds(scrollIds).execute().actionGet(); try {
client.prepareClearScroll().setScrollIds(scrollIds).execute().actionGet();
} catch (Exception e) {
LOGGER.warn("Could not clear scroll: {}", e.getMessage());
}
} }
@Override @Override

View File

@ -18,6 +18,7 @@ package org.springframework.data.elasticsearch.core;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
@ -38,13 +39,15 @@ abstract class StreamQueries {
/** /**
* Stream query results using {@link SearchScrollHits}. * Stream query results using {@link SearchScrollHits}.
* *
* @param maxCount the maximum number of entities to return, a value of 0 means that all available entities are
* returned
* @param searchHits the initial hits * @param searchHits the initial hits
* @param continueScrollFunction function to continue scrolling applies to the current scrollId. * @param continueScrollFunction function to continue scrolling applies to the current scrollId.
* @param clearScrollConsumer consumer to clear the scroll context by accepting the scrollIds to clear. * @param clearScrollConsumer consumer to clear the scroll context by accepting the scrollIds to clear.
* @param <T> * @param <T> the entity type
* @return the {@link SearchHitsIterator}. * @return the {@link SearchHitsIterator}.
*/ */
static <T> SearchHitsIterator<T> streamResults(SearchScrollHits<T> searchHits, static <T> SearchHitsIterator<T> streamResults(int maxCount, SearchScrollHits<T> searchHits,
Function<String, SearchScrollHits<T>> continueScrollFunction, Consumer<List<String>> clearScrollConsumer) { Function<String, SearchScrollHits<T>> continueScrollFunction, Consumer<List<String>> clearScrollConsumer) {
Assert.notNull(searchHits, "searchHits must not be null."); Assert.notNull(searchHits, "searchHits must not be null.");
@ -59,20 +62,14 @@ abstract class StreamQueries {
return new SearchHitsIterator<T>() { return new SearchHitsIterator<T>() {
// As we couldn't retrieve single result with scroll, store current hits. private volatile AtomicInteger currentCount = new AtomicInteger();
private volatile Iterator<SearchHit<T>> scrollHits = searchHits.iterator(); private volatile Iterator<SearchHit<T>> currentScrollHits = searchHits.iterator();
private volatile boolean continueScroll = scrollHits.hasNext(); private volatile boolean continueScroll = currentScrollHits.hasNext();
private volatile ScrollState scrollState = new ScrollState(searchHits.getScrollId()); private volatile ScrollState scrollState = new ScrollState(searchHits.getScrollId());
@Override @Override
public void close() { public void close() {
clearScrollConsumer.accept(scrollState.getScrollIds());
try {
clearScrollConsumer.accept(scrollState.getScrollIds());
} finally {
scrollHits = null;
scrollState = null;
}
} }
@Override @Override
@ -99,24 +96,25 @@ abstract class StreamQueries {
@Override @Override
public boolean hasNext() { public boolean hasNext() {
if (!continueScroll) { if (!continueScroll || (maxCount > 0 && currentCount.get() >= maxCount)) {
return false; return false;
} }
if (!scrollHits.hasNext()) { if (!currentScrollHits.hasNext()) {
SearchScrollHits<T> nextPage = continueScrollFunction.apply(scrollState.getScrollId()); SearchScrollHits<T> nextPage = continueScrollFunction.apply(scrollState.getScrollId());
scrollHits = nextPage.iterator(); currentScrollHits = nextPage.iterator();
scrollState.updateScrollId(nextPage.getScrollId()); scrollState.updateScrollId(nextPage.getScrollId());
continueScroll = scrollHits.hasNext(); continueScroll = currentScrollHits.hasNext();
} }
return scrollHits.hasNext(); return currentScrollHits.hasNext();
} }
@Override @Override
public SearchHit<T> next() { public SearchHit<T> next() {
if (hasNext()) { if (hasNext()) {
return scrollHits.next(); currentCount.incrementAndGet();
return currentScrollHits.next();
} }
throw new NoSuchElementException(); throw new NoSuchElementException();
} }

View File

@ -77,7 +77,7 @@ import org.springframework.data.elasticsearch.annotations.ScriptedField;
import org.springframework.data.elasticsearch.core.geo.GeoPoint; import org.springframework.data.elasticsearch.core.geo.GeoPoint;
import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
import org.springframework.data.elasticsearch.core.query.*; import org.springframework.data.elasticsearch.core.query.*;
import org.springframework.data.util.CloseableIterator; import org.springframework.data.util.StreamUtils;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
/** /**
@ -1298,27 +1298,33 @@ public abstract class ElasticsearchTemplateTests {
assertThat(sampleEntities).hasSize(30); assertThat(sampleEntities).hasSize(30);
} }
@Test // DATAES-167 @Test // DATAES-167, DATAES-831
public void shouldReturnResultsWithStreamForGivenCriteriaQuery() { public void shouldReturnAllResultsWithStreamForGivenCriteriaQuery() {
// given operations.bulkIndex(createSampleEntitiesWithMessage("Test message", 30), index);
List<IndexQuery> entities = createSampleEntitiesWithMessage("Test message", 30);
// when
operations.bulkIndex(entities, index);
indexOperations.refresh(); indexOperations.refresh();
// then
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria());
criteriaQuery.setPageable(PageRequest.of(0, 10)); criteriaQuery.setPageable(PageRequest.of(0, 10));
CloseableIterator<SearchHit<SampleEntity>> stream = operations.searchForStream(criteriaQuery, SampleEntity.class, long count = StreamUtils
index); .createStreamFromIterator(operations.searchForStream(criteriaQuery, SampleEntity.class, index)).count();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (stream.hasNext()) { assertThat(count).isEqualTo(30);
sampleEntities.add(stream.next()); }
}
assertThat(sampleEntities).hasSize(30); @Test // DATAES-831
void shouldLimitStreamResultToRequestedSize() {
operations.bulkIndex(createSampleEntitiesWithMessage("Test message", 30), index);
indexOperations.refresh();
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria());
criteriaQuery.setMaxResults(10);
long count = StreamUtils
.createStreamFromIterator(operations.searchForStream(criteriaQuery, SampleEntity.class, index)).count();
assertThat(count).isEqualTo(10);
} }
private static List<IndexQuery> createSampleEntitiesWithMessage(String message, int numberOfEntities) { private static List<IndexQuery> createSampleEntitiesWithMessage(String message, int numberOfEntities) {
@ -3128,8 +3134,8 @@ public abstract class ElasticsearchTemplateTests {
operations.refresh(OptimisticEntity.class); operations.refresh(OptimisticEntity.class);
List<Query> queries = singletonList(queryForOne(saved.getId())); List<Query> queries = singletonList(queryForOne(saved.getId()));
List<SearchHits<OptimisticEntity>> retrievedHits = operations.multiSearch(queries, List<SearchHits<OptimisticEntity>> retrievedHits = operations.multiSearch(queries, OptimisticEntity.class,
OptimisticEntity.class, operations.getIndexCoordinatesFor(OptimisticEntity.class)); operations.getIndexCoordinatesFor(OptimisticEntity.class));
OptimisticEntity retrieved = retrievedHits.get(0).getSearchHit(0).getContent(); OptimisticEntity retrieved = retrievedHits.get(0).getSearchHit(0).getContent();
assertThatSeqNoPrimaryTermIsFilled(retrieved); assertThatSeqNoPrimaryTermIsFilled(retrieved);
@ -3162,8 +3168,7 @@ public abstract class ElasticsearchTemplateTests {
operations.save(forEdit1); operations.save(forEdit1);
forEdit2.setMessage("It'll be great"); forEdit2.setMessage("It'll be great");
assertThatThrownBy(() -> operations.save(forEdit2)) assertThatThrownBy(() -> operations.save(forEdit2)).isInstanceOf(OptimisticLockingFailureException.class);
.isInstanceOf(OptimisticLockingFailureException.class);
} }
@Test // DATAES-799 @Test // DATAES-799
@ -3179,8 +3184,7 @@ public abstract class ElasticsearchTemplateTests {
operations.save(forEdit1); operations.save(forEdit1);
forEdit2.setMessage("It'll be great"); forEdit2.setMessage("It'll be great");
assertThatThrownBy(() -> operations.save(forEdit2)) assertThatThrownBy(() -> operations.save(forEdit2)).isInstanceOf(OptimisticLockingFailureException.class);
.isInstanceOf(OptimisticLockingFailureException.class);
} }
@Test // DATAES-799 @Test // DATAES-799

View File

@ -25,6 +25,7 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.data.util.StreamUtils;
/** /**
* @author Sascha Woo * @author Sascha Woo
@ -45,6 +46,7 @@ public class StreamQueriesTest {
// when // when
SearchHitsIterator<String> iterator = StreamQueries.streamResults( // SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
0, //
searchHits, // searchHits, //
scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), //
scrollIds -> clearScrollCalled.set(true)); scrollIds -> clearScrollCalled.set(true));
@ -70,6 +72,7 @@ public class StreamQueriesTest {
// when // when
SearchHitsIterator<String> iterator = StreamQueries.streamResults( // SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
0, //
searchHits, // searchHits, //
scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), // scrollId -> newSearchScrollHits(Collections.emptyList(), scrollId), //
scrollId -> {}); scrollId -> {});
@ -90,10 +93,12 @@ public class StreamQueriesTest {
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-2"); Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-2");
SearchScrollHits<String> searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3"); SearchScrollHits<String> searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3");
Iterator<SearchScrollHits<String>> searchScrollHitsIterator = Arrays.asList(searchHits1, searchHits2, searchHits3,searchHits4).iterator(); Iterator<SearchScrollHits<String>> searchScrollHitsIterator = Arrays
.asList(searchHits1, searchHits2, searchHits3, searchHits4).iterator();
List<String> clearedScrollIds = new ArrayList<>(); List<String> clearedScrollIds = new ArrayList<>();
SearchHitsIterator<String> iterator = StreamQueries.streamResults( // SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
0, //
searchScrollHitsIterator.next(), // searchScrollHitsIterator.next(), //
scrollId -> searchScrollHitsIterator.next(), // scrollId -> searchScrollHitsIterator.next(), //
scrollIds -> clearedScrollIds.addAll(scrollIds)); scrollIds -> clearedScrollIds.addAll(scrollIds));
@ -106,6 +111,56 @@ public class StreamQueriesTest {
assertThat(clearedScrollIds).isEqualTo(Arrays.asList("s-1", "s-2", "s-3")); assertThat(clearedScrollIds).isEqualTo(Arrays.asList("s-1", "s-2", "s-3"));
} }
@Test // DATAES-831
void shouldReturnAllForRequestedSizeOf0() {
SearchScrollHits<String> searchHits1 = newSearchScrollHits(
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-1");
SearchScrollHits<String> searchHits2 = newSearchScrollHits(
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-2");
SearchScrollHits<String> searchHits3 = newSearchScrollHits(
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-2");
SearchScrollHits<String> searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3");
Iterator<SearchScrollHits<String>> searchScrollHitsIterator = Arrays
.asList(searchHits1, searchHits2, searchHits3, searchHits4).iterator();
SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
0, //
searchScrollHitsIterator.next(), //
scrollId -> searchScrollHitsIterator.next(), //
scrollIds -> {});
long count = StreamUtils.createStreamFromIterator(iterator).count();
assertThat(count).isEqualTo(3);
}
@Test // DATAES-831
void shouldOnlyReturnRequestedCount() {
SearchScrollHits<String> searchHits1 = newSearchScrollHits(
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-1");
SearchScrollHits<String> searchHits2 = newSearchScrollHits(
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-2");
SearchScrollHits<String> searchHits3 = newSearchScrollHits(
Collections.singletonList(new SearchHit<String>(null, 0, null, null, "one")), "s-2");
SearchScrollHits<String> searchHits4 = newSearchScrollHits(Collections.emptyList(), "s-3");
Iterator<SearchScrollHits<String>> searchScrollHitsIterator = Arrays
.asList(searchHits1, searchHits2, searchHits3, searchHits4).iterator();
SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
2, //
searchScrollHitsIterator.next(), //
scrollId -> searchScrollHitsIterator.next(), //
scrollIds -> {});
long count = StreamUtils.createStreamFromIterator(iterator).count();
assertThat(count).isEqualTo(2);
}
private SearchScrollHits<String> newSearchScrollHits(List<SearchHit<String>> hits, String scrollId) { private SearchScrollHits<String> newSearchScrollHits(List<SearchHit<String>> hits, String scrollId) {
return new SearchHitsImpl<String>(hits.size(), TotalHitsRelation.EQUAL_TO, 0, scrollId, hits, null); return new SearchHitsImpl<String>(hits.size(), TotalHitsRelation.EQUAL_TO, 0, scrollId, hits, null);
} }