DATAES-766 - Replace CloseableIterator with SearchHitsIterator in stream operations.

Original pull request: #407
This commit is contained in:
xhaggi 2020-03-20 14:09:08 +01:00
parent db28d93676
commit f354f986ca
16 changed files with 518 additions and 276 deletions

View File

@ -208,20 +208,25 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper
@Override @Override
public <T> CloseableIterator<T> stream(Query query, Class<T> clazz, IndexCoordinates index) { public <T> CloseableIterator<T> stream(Query query, Class<T> clazz, IndexCoordinates index) {
long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis(); long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis();
return (CloseableIterator<T>) SearchHitSupport.unwrapSearchHits(searchForStream(query, clazz, index)); return (CloseableIterator<T>) SearchHitSupport.unwrapSearchHits(searchForStream(query, clazz, index));
} }
@Override @Override
public <T> CloseableIterator<SearchHit<T>> searchForStream(Query query, Class<T> clazz) { public <T> SearchHitsIterator<T> searchForStream(Query query, Class<T> clazz) {
return searchForStream(query, clazz, getIndexCoordinatesFor(clazz)); return searchForStream(query, clazz, getIndexCoordinatesFor(clazz));
} }
@Override @Override
public <T> CloseableIterator<SearchHit<T>> searchForStream(Query query, Class<T> clazz, IndexCoordinates index) { public <T> SearchHitsIterator<T> searchForStream(Query query, Class<T> clazz, IndexCoordinates index) {
long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis(); long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis();
return StreamQueries.streamResults(searchScrollStart(scrollTimeInMillis, query, clazz, index),
scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz), this::searchScrollClear); return StreamQueries.streamResults( //
searchScrollStart(scrollTimeInMillis, query, clazz, index), //
scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz), //
this::searchScrollClear);
} }
@Override @Override
@ -283,13 +288,13 @@ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOper
/* /*
* internal use only, not for public API * internal use only, not for public API
*/ */
abstract protected <T> ScrolledPage<SearchHit<T>> searchScrollStart(long scrollTimeInMillis, Query query, abstract protected <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query query,
Class<T> clazz, IndexCoordinates index); Class<T> clazz, IndexCoordinates index);
/* /*
* internal use only, not for public API * internal use only, not for public API
*/ */
abstract protected <T> ScrolledPage<SearchHit<T>> searchScrollContinue(@Nullable String scrollId, abstract protected <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId,
long scrollTimeInMillis, Class<T> clazz); long scrollTimeInMillis, Class<T> clazz);
/* /*

View File

@ -39,7 +39,6 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.SuggestBuilder;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter; import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter;
import org.springframework.data.elasticsearch.core.document.DocumentAdapters; import org.springframework.data.elasticsearch.core.document.DocumentAdapters;
import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse;
@ -257,24 +256,28 @@ public class ElasticsearchRestTemplate extends AbstractElasticsearchTemplate {
} }
@Override @Override
public <T> ScrolledPage<SearchHit<T>> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz, public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz,
IndexCoordinates index) { IndexCoordinates index) {
Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll"); Assert.notNull(query.getPageable(), "pageable of query must not be null.");
SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index); SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index);
searchRequest.scroll(TimeValue.timeValueMillis(scrollTimeInMillis)); searchRequest.scroll(TimeValue.timeValueMillis(scrollTimeInMillis));
SearchResponse result = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT));
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(result), clazz, null); SearchResponse response = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT));
return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
} }
@Override @Override
public <T> ScrolledPage<SearchHit<T>> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, public <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class<T> clazz) {
Class<T> clazz) {
SearchScrollRequest request = new SearchScrollRequest(scrollId); SearchScrollRequest request = new SearchScrollRequest(scrollId);
request.scroll(TimeValue.timeValueMillis(scrollTimeInMillis)); request.scroll(TimeValue.timeValueMillis(scrollTimeInMillis));
SearchResponse response = execute(client -> client.searchScroll(request, RequestOptions.DEFAULT));
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, Pageable.unpaged()); SearchResponse response = execute(client -> client.scroll(request, RequestOptions.DEFAULT));
return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
} }
@Override @Override

View File

@ -27,6 +27,7 @@ import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
@ -260,22 +261,32 @@ public class ElasticsearchTemplate extends AbstractElasticsearchTemplate {
} }
@Override @Override
public <T> ScrolledPage<SearchHit<T>> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz, public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz,
IndexCoordinates index) { IndexCoordinates index) {
Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll");
SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index); Assert.notNull(query.getPageable(), "pageable of query must not be null.");
searchRequestBuilder.setScroll(TimeValue.timeValueMillis(scrollTimeInMillis));
SearchResponse response = getSearchResponse(searchRequestBuilder); ActionFuture<SearchResponse> action = requestFactory //
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, null); .searchRequestBuilder(client, query, clazz, index) //
.setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)) //
.execute();
SearchResponse response = getSearchResponseWithTimeout(action);
return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
} }
@Override @Override
public <T> ScrolledPage<SearchHit<T>> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, public <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class<T> clazz) {
Class<T> clazz) {
SearchResponse response = getSearchResponseWithTimeout( ActionFuture<SearchResponse> action = client //
client.prepareSearchScroll(scrollId).setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)).execute()); .prepareSearchScroll(scrollId) //
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, Pageable.unpaged()); .setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)) //
.execute();
SearchResponse response = getSearchResponseWithTimeout(action);
return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
} }
@Override @Override

View File

@ -2,13 +2,14 @@
package org.springframework.data.elasticsearch.core; package org.springframework.data.elasticsearch.core;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.lang.Nullable;
/** /**
* @author Artur Konczak * @author Artur Konczak
* @author Peter-Josef Meisch * @author Peter-Josef Meisch
* @author Sascha Woo * @author Sascha Woo
* @deprecated since 4.0, will be removed in a future version.
*/ */
@Deprecated
public interface ScrolledPage<T> extends Page<T> { public interface ScrolledPage<T> extends Page<T> {
String getScrollId(); String getScrollId();

View File

@ -32,6 +32,7 @@ import org.springframework.lang.Nullable;
* Utility class with helper methods for working with {@link SearchHit}. * Utility class with helper methods for working with {@link SearchHit}.
* *
* @author Peter-Josef Meisch * @author Peter-Josef Meisch
* @author Sascha Woo
* @since 4.0 * @since 4.0
*/ */
public final class SearchHitSupport { public final class SearchHitSupport {
@ -95,10 +96,17 @@ public final class SearchHitSupport {
* @param searchHits, must not be {@literal null}. * @param searchHits, must not be {@literal null}.
* @param pageable, must not be {@literal null}. * @param pageable, must not be {@literal null}.
* @return the created Page * @return the created Page
* @deprecated since 4.0, will be removed in a future version.
*/ */
@Deprecated
public static <T> AggregatedPage<SearchHit<T>> page(SearchHits<T> searchHits, Pageable pageable) { public static <T> AggregatedPage<SearchHit<T>> page(SearchHits<T> searchHits, Pageable pageable) {
return new AggregatedPageImpl<>(searchHits.getSearchHits(), pageable, searchHits.getTotalHits(), return new AggregatedPageImpl<>( //
searchHits.getAggregations(), searchHits.getScrollId(), searchHits.getMaxScore()); searchHits.getSearchHits(), //
pageable, //
searchHits.getTotalHits(), //
searchHits.getAggregations(), //
null, //
searchHits.getMaxScore());
} }
public static <T> SearchPage<T> searchPageFor(SearchHits<T> searchHits, @Nullable Pageable pageable) { public static <T> SearchPage<T> searchPageFor(SearchHits<T> searchHits, @Nullable Pageable pageable) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2019-2020 the original author or authors. * Copyright 2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,143 +15,74 @@
*/ */
package org.springframework.data.elasticsearch.core; package org.springframework.data.elasticsearch.core;
import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.springframework.data.util.Streamable; import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/** /**
* Encapsulates a list of {@link SearchHit}s with additional information from the search. * Encapsulates a list of {@link SearchHit}s with additional information from the search.
* *
* @param <T> the result data class. * @param <T> the result data class.
* @author Peter-Josef Meisch * @author Sascha Woo
* @since 4.0 * @since 4.0
*/ */
public class SearchHits<T> implements Streamable<SearchHit<T>> { public interface SearchHits<T> extends Streamable<SearchHit<T>> {
private final long totalHits;
private final TotalHitsRelation totalHitsRelation;
private final float maxScore;
private final String scrollId;
private final List<? extends SearchHit<T>> searchHits;
private final Aggregations aggregations;
/**
* @param totalHits the number of total hits for the search
* @param totalHitsRelation the relation {@see TotalHitsRelation}, must not be {@literal null}
* @param maxScore the maximum score
* @param scrollId the scroll id if available
* @param searchHits must not be {@literal null}
* @param aggregations the aggregations if available
*/
public SearchHits(long totalHits, TotalHitsRelation totalHitsRelation, float maxScore, @Nullable String scrollId,
List<? extends SearchHit<T>> searchHits, @Nullable Aggregations aggregations) {
Assert.notNull(searchHits, "searchHits must not be null");
this.totalHits = totalHits;
this.totalHitsRelation = totalHitsRelation;
this.maxScore = maxScore;
this.scrollId = scrollId;
this.searchHits = searchHits;
this.aggregations = aggregations;
}
@SuppressWarnings("unchecked")
@Override
public Iterator<SearchHit<T>> iterator() {
return (Iterator<SearchHit<T>>) searchHits.iterator();
}
// region getter
/**
* @return the number of total hits.
*/
public long getTotalHits() {
return totalHits;
}
/**
* @return the relation for the total hits
*/
public TotalHitsRelation getTotalHitsRelation() {
return totalHitsRelation;
}
/**
* @return the maximum score
*/
public float getMaxScore() {
return maxScore;
}
/**
* @return the scroll id
*/
@Nullable
public String getScrollId() {
return scrollId;
}
/**
* @return the contained {@link SearchHit}s.
*/
public List<SearchHit<T>> getSearchHits() {
return Collections.unmodifiableList(searchHits);
}
// endregion
// region SearchHit access
/**
* @param index position in List.
* @return the {@link SearchHit} at position {index}
* @throws IndexOutOfBoundsException on invalid index
*/
public SearchHit<T> getSearchHit(int index) {
return searchHits.get(index);
}
// endregion
@Override
public String toString() {
return "SearchHits{" + //
"totalHits=" + totalHits + //
", totalHitsRelation=" + totalHitsRelation + //
", maxScore=" + maxScore + //
", scrollId='" + scrollId + '\'' + //
", searchHits={" + searchHits.size() + " elements}" + //
", aggregations=" + aggregations + //
'}';
}
// region aggregations
/**
* @return true if aggregations are available
*/
public boolean hasAggregations() {
return aggregations != null;
}
/** /**
* @return the aggregations. * @return the aggregations.
*/ */
@Nullable @Nullable
public Aggregations getAggregations() { Aggregations getAggregations();
return aggregations;
}
// endregion
/** /**
* Enum to represent the relation that Elasticsearch returns for the totalHits value {@see <a href= * @return the maximum score
* "https://www.elastic.co/guide/en/elasticsearch/reference/7.5/search-request-body.html#request-body-search-track-total-hits">Ekasticsearch
* docs</a>}
*/ */
public enum TotalHitsRelation { float getMaxScore();
EQUAL_TO, GREATER_THAN_OR_EQUAL_TO
/**
* @param index position in List.
* @return the {@link SearchHit} at position {index}
* @throws IndexOutOfBoundsException on invalid index
*/
SearchHit<T> getSearchHit(int index);
/**
* @return the contained {@link SearchHit}s.
*/
List<SearchHit<T>> getSearchHits();
/**
* @return the number of total hits.
*/
long getTotalHits();
/**
* @return the relation for the total hits
*/
TotalHitsRelation getTotalHitsRelation();
/**
* @return true if aggregations are available
*/
default boolean hasAggregations() {
return getAggregations() != null;
} }
/**
* @return whether the {@link SearchHits} has search hits.
*/
default boolean hasSearchHits() {
return !getSearchHits().isEmpty();
}
/**
* @return an iterator for {@link SearchHit}
*/
default Iterator<SearchHit<T>> iterator() {
return getSearchHits().iterator();
}
} }

View File

@ -0,0 +1,117 @@
/*
* Copyright 2019-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.core;
import java.util.Collections;
import java.util.List;
import org.elasticsearch.search.aggregations.Aggregations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Basic implementation of {@link SearchScrollHits}
*
* @param <T> the result data class.
* @author Peter-Josef Meisch
* @author Sascha Woo
* @since 4.0
*/
public class SearchHitsImpl<T> implements SearchScrollHits<T> {
private final long totalHits;
private final TotalHitsRelation totalHitsRelation;
private final float maxScore;
private final String scrollId;
private final List<? extends SearchHit<T>> searchHits;
private final Aggregations aggregations;
/**
* @param totalHits the number of total hits for the search
* @param totalHitsRelation the relation {@see TotalHitsRelation}, must not be {@literal null}
* @param maxScore the maximum score
* @param scrollId the scroll id if available
* @param searchHits must not be {@literal null}
* @param aggregations the aggregations if available
*/
public SearchHitsImpl(long totalHits, TotalHitsRelation totalHitsRelation, float maxScore, @Nullable String scrollId,
List<? extends SearchHit<T>> searchHits, @Nullable Aggregations aggregations) {
Assert.notNull(searchHits, "searchHits must not be null");
this.totalHits = totalHits;
this.totalHitsRelation = totalHitsRelation;
this.maxScore = maxScore;
this.scrollId = scrollId;
this.searchHits = searchHits;
this.aggregations = aggregations;
}
// region getter
@Override
public long getTotalHits() {
return totalHits;
}
@Override
public TotalHitsRelation getTotalHitsRelation() {
return totalHitsRelation;
}
@Override
public float getMaxScore() {
return maxScore;
}
@Override
@Nullable
public String getScrollId() {
return scrollId;
}
@Override
public List<SearchHit<T>> getSearchHits() {
return Collections.unmodifiableList(searchHits);
}
// endregion
// region SearchHit access
@Override
public SearchHit<T> getSearchHit(int index) {
return searchHits.get(index);
}
// endregion
@Override
public String toString() {
return "SearchHits{" + //
"totalHits=" + totalHits + //
", totalHitsRelation=" + totalHitsRelation + //
", maxScore=" + maxScore + //
", scrollId='" + scrollId + '\'' + //
", searchHits={" + searchHits.size() + " elements}" + //
", aggregations=" + aggregations + //
'}';
}
// region aggregations
@Override
@Nullable
public Aggregations getAggregations() {
return aggregations;
}
// endregion
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.core;
import org.elasticsearch.search.aggregations.Aggregations;
import org.springframework.data.util.CloseableIterator;
import org.springframework.lang.Nullable;
/**
* A {@link SearchHitsIterator} encapsulates {@link SearchHit} results that can be wrapped in a Java 8
* {@link java.util.stream.Stream}.
*
* @author Sascha Woo
* @param <T>
* @since 4.0
*/
public interface SearchHitsIterator<T> extends CloseableIterator<SearchHit<T>> {
/**
* @return the aggregations.
*/
@Nullable
Aggregations getAggregations();
/**
* @return the maximum score
*/
float getMaxScore();
/**
* @return the number of total hits.
*/
long getTotalHits();
/**
* @return the relation for the total hits
*/
TotalHitsRelation getTotalHitsRelation();
/**
* @return true if aggregations are available
*/
default boolean hasAggregations() {
return getAggregations() != null;
}
}

View File

@ -36,6 +36,7 @@ import org.springframework.lang.Nullable;
* APIs</a>. * APIs</a>.
* *
* @author Peter-Josef Meisch * @author Peter-Josef Meisch
* @author Sascha Woo
* @since 4.0 * @since 4.0
*/ */
public interface SearchOperations { public interface SearchOperations {
@ -155,8 +156,9 @@ public interface SearchOperations {
* @param query the query to execute * @param query the query to execute
* @param clazz the entity clazz used for property mapping * @param clazz the entity clazz used for property mapping
* @param index the index to run the query against * @param index the index to run the query against
* @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed in case of * * @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed. The
* error. * try-with-resources construct should be used to ensure that the close method is invoked after the operations
* are completed.
* @deprecated since 4.0, use {@link #searchForStream(Query, Class, IndexCoordinates)}. * @deprecated since 4.0, use {@link #searchForStream(Query, Class, IndexCoordinates)}.
*/ */
@Deprecated @Deprecated
@ -237,7 +239,6 @@ public interface SearchOperations {
return (AggregatedPage<T>) SearchHitSupport.unwrapSearchHits(aggregatedPage); return (AggregatedPage<T>) SearchHitSupport.unwrapSearchHits(aggregatedPage);
} }
// endregion // endregion
/** /**
@ -340,27 +341,29 @@ public interface SearchOperations {
<T> SearchHits<T> search(MoreLikeThisQuery query, Class<T> clazz, IndexCoordinates index); <T> SearchHits<T> search(MoreLikeThisQuery query, Class<T> clazz, IndexCoordinates index);
/** /**
* Executes the given {@link Query} against elasticsearch and return result as {@link CloseableIterator}. * Executes the given {@link Query} against elasticsearch and return result as {@link SearchHitsIterator}.
* <p> * <p>
* *
* @param <T> element return type * @param <T> element return type
* @param query the query to execute * @param query the query to execute
* @param clazz the entity clazz used for property mapping and index name extraction * @param clazz the entity clazz used for property mapping and index name extraction
* @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed in case of * * @return a {@link SearchHitsIterator} that wraps an Elasticsearch scroll context that needs to be closed. The
* error. * try-with-resources construct should be used to ensure that the close method is invoked after the operations
* are completed.
*/ */
<T> CloseableIterator<SearchHit<T>> searchForStream(Query query, Class<T> clazz); <T> SearchHitsIterator<T> searchForStream(Query query, Class<T> clazz);
/** /**
* Executes the given {@link Query} against elasticsearch and return result as {@link CloseableIterator}. * Executes the given {@link Query} against elasticsearch and return result as {@link SearchHitsIterator}.
* <p> * <p>
* *
* @param <T> element return type * @param <T> element return type
* @param query the query to execute * @param query the query to execute
* @param clazz the entity clazz used for property mapping * @param clazz the entity clazz used for property mapping
* @param index the index to run the query against * @param index the index to run the query against
* @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed in case of * * @return a {@link SearchHitsIterator} that wraps an Elasticsearch scroll context that needs to be closed. The
* error. * try-with-resources construct should be used to ensure that the close method is invoked after the operations
* are completed.
*/ */
<T> CloseableIterator<SearchHit<T>> searchForStream(Query query, Class<T> clazz, IndexCoordinates index); <T> SearchHitsIterator<T> searchForStream(Query query, Class<T> clazz, IndexCoordinates index);
} }

View File

@ -0,0 +1,34 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.core;
/**
* This interface is used to expose the current {@code scrollId} from the underlying scroll context.
* <p>
* Internal use only.
*
* @author Sascha Woo
* @param <T>
* @since 4.0
*/
public interface SearchScrollHits<T> extends SearchHits<T> {
/**
* @return the scroll id
*/
String getScrollId();
}

View File

@ -20,7 +20,8 @@ import java.util.NoSuchElementException;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import org.springframework.data.util.CloseableIterator; import org.elasticsearch.search.aggregations.Aggregations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -33,27 +34,32 @@ import org.springframework.util.Assert;
abstract class StreamQueries { abstract class StreamQueries {
/** /**
* Stream query results using {@link ScrolledPage}. * Stream query results using {@link SearchScrollHits}.
* *
* @param page the initial scrolled page. * @param searchHits the initial hits
* @param continueScrollFunction function to continue scrolling applies to the current scrollId. * @param continueScrollFunction function to continue scrolling applies to the current scrollId.
* @param clearScrollConsumer consumer to clear the scroll context by accepting the current scrollId. * @param clearScrollConsumer consumer to clear the scroll context by accepting the current scrollId.
* @param <T> * @param <T>
* @return the {@link CloseableIterator}. * @return the {@link SearchHitsIterator}.
*/ */
static <T> CloseableIterator<T> streamResults(ScrolledPage<T> page, static <T> SearchHitsIterator<T> streamResults(SearchScrollHits<T> searchHits,
Function<String, ScrolledPage<T>> continueScrollFunction, Consumer<String> clearScrollConsumer) { Function<String, SearchScrollHits<T>> continueScrollFunction, Consumer<String> clearScrollConsumer) {
Assert.notNull(page, "page must not be null."); Assert.notNull(searchHits, "searchHits must not be null.");
Assert.notNull(page.getScrollId(), "scrollId must not be null."); Assert.notNull(searchHits.getScrollId(), "scrollId of searchHits must not be null.");
Assert.notNull(continueScrollFunction, "continueScrollFunction must not be null."); Assert.notNull(continueScrollFunction, "continueScrollFunction must not be null.");
Assert.notNull(clearScrollConsumer, "clearScrollConsumer must not be null."); Assert.notNull(clearScrollConsumer, "clearScrollConsumer must not be null.");
return new CloseableIterator<T>() { Aggregations aggregations = searchHits.getAggregations();
float maxScore = searchHits.getMaxScore();
long totalHits = searchHits.getTotalHits();
TotalHitsRelation totalHitsRelation = searchHits.getTotalHitsRelation();
return new SearchHitsIterator<T>() {
// As we couldn't retrieve single result with scroll, store current hits. // As we couldn't retrieve single result with scroll, store current hits.
private volatile Iterator<T> scrollHits = page.iterator(); private volatile Iterator<SearchHit<T>> scrollHits = searchHits.iterator();
private volatile String scrollId = page.getScrollId(); private volatile String scrollId = searchHits.getScrollId();
private volatile boolean continueScroll = scrollHits.hasNext(); private volatile boolean continueScroll = scrollHits.hasNext();
@Override @Override
@ -67,6 +73,27 @@ abstract class StreamQueries {
} }
} }
@Override
@Nullable
public Aggregations getAggregations() {
return aggregations;
}
@Override
public float getMaxScore() {
return maxScore;
}
@Override
public long getTotalHits() {
return totalHits;
}
@Override
public TotalHitsRelation getTotalHitsRelation() {
return totalHitsRelation;
}
@Override @Override
public boolean hasNext() { public boolean hasNext() {
@ -75,7 +102,7 @@ abstract class StreamQueries {
} }
if (!scrollHits.hasNext()) { if (!scrollHits.hasNext()) {
ScrolledPage<T> nextPage = continueScrollFunction.apply(scrollId); SearchScrollHits<T> nextPage = continueScrollFunction.apply(scrollId);
scrollHits = nextPage.iterator(); scrollHits = nextPage.iterator();
scrollId = nextPage.getScrollId(); scrollId = nextPage.getScrollId();
continueScroll = scrollHits.hasNext(); continueScroll = scrollHits.hasNext();
@ -85,7 +112,7 @@ abstract class StreamQueries {
} }
@Override @Override
public T next() { public SearchHit<T> next() {
if (hasNext()) { if (hasNext()) {
return scrollHits.next(); return scrollHits.next();
} }

View File

@ -0,0 +1,30 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.core;
/**
* Enum to represent the relation that Elasticsearch returns for the totalHits value {@see <a href=
* "https://www.elastic.co/guide/en/elasticsearch/reference/7.5/search-request-body.html#request-body-search-track-total-hits">Ekasticsearch
* docs</a>}
*
* @author Peter-Josef Meisch
* @author Sascha Woo
* @since 4.0
*/
public enum TotalHitsRelation {
EQUAL_TO, //
GREATER_THAN_OR_EQUAL_TO
}

View File

@ -19,10 +19,9 @@ import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.springframework.data.convert.EntityConverter; import org.springframework.data.convert.EntityConverter;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.SearchHit; import org.springframework.data.elasticsearch.core.SearchHit;
import org.springframework.data.elasticsearch.core.SearchHits; import org.springframework.data.elasticsearch.core.SearchHits;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage; import org.springframework.data.elasticsearch.core.SearchScrollHits;
import org.springframework.data.elasticsearch.core.document.Document; import org.springframework.data.elasticsearch.core.document.Document;
import org.springframework.data.elasticsearch.core.document.SearchDocument; import org.springframework.data.elasticsearch.core.document.SearchDocument;
import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse;
@ -39,6 +38,7 @@ import org.springframework.util.Assert;
* @author Mohsin Husen * @author Mohsin Husen
* @author Christoph Strobl * @author Christoph Strobl
* @author Peter-Josef Meisch * @author Peter-Josef Meisch
* @author Sasch Woo
*/ */
public interface ElasticsearchConverter public interface ElasticsearchConverter
extends EntityConverter<ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty, Object, Document> { extends EntityConverter<ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty, Object, Document> {
@ -90,6 +90,17 @@ public interface ElasticsearchConverter
*/ */
<T> SearchHits<T> read(Class<T> type, SearchDocumentResponse searchDocumentResponse); <T> SearchHits<T> read(Class<T> type, SearchDocumentResponse searchDocumentResponse);
/**
* builds a {@link SearchScrollHits} from a {@link SearchDocumentResponse}.
*
* @param <T> the clazz of the type, must not be {@literal null}.
* @param type the type of the returned data, must not be {@literal null}.
* @param searchDocumentResponse the response to read from, must not be {@literal null}.
* @return a {@link SearchScrollHits} object
* @since 4.0
*/
<T> SearchScrollHits<T> readScroll(Class<T> type, SearchDocumentResponse searchDocumentResponse);
/** /**
* builds a {@link SearchHit} from a {@link SearchDocument}. * builds a {@link SearchHit} from a {@link SearchDocument}.
* *
@ -101,9 +112,6 @@ public interface ElasticsearchConverter
*/ */
<T> SearchHit<T> read(Class<T> type, SearchDocument searchDocument); <T> SearchHit<T> read(Class<T> type, SearchDocument searchDocument);
<T> AggregatedPage<SearchHit<T>> mapResults(SearchDocumentResponse response, Class<T> clazz,
@Nullable Pageable pageable);
// endregion // endregion
// region write // region write

View File

@ -38,13 +38,13 @@ import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.convert.support.GenericConversionService; import org.springframework.core.convert.support.GenericConversionService;
import org.springframework.data.convert.CustomConversions; import org.springframework.data.convert.CustomConversions;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.ElasticsearchException; import org.springframework.data.elasticsearch.ElasticsearchException;
import org.springframework.data.elasticsearch.annotations.ScriptedField; import org.springframework.data.elasticsearch.annotations.ScriptedField;
import org.springframework.data.elasticsearch.core.SearchScrollHits;
import org.springframework.data.elasticsearch.core.SearchHit; import org.springframework.data.elasticsearch.core.SearchHit;
import org.springframework.data.elasticsearch.core.SearchHits; import org.springframework.data.elasticsearch.core.SearchHits;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage; import org.springframework.data.elasticsearch.core.SearchHitsImpl;
import org.springframework.data.elasticsearch.core.aggregation.impl.AggregatedPageImpl; import org.springframework.data.elasticsearch.core.TotalHitsRelation;
import org.springframework.data.elasticsearch.core.document.Document; import org.springframework.data.elasticsearch.core.document.Document;
import org.springframework.data.elasticsearch.core.document.SearchDocument; import org.springframework.data.elasticsearch.core.document.SearchDocument;
import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse;
@ -146,34 +146,9 @@ public class MappingElasticsearchConverter
// region read // region read
@Override
public <T> AggregatedPage<SearchHit<T>> mapResults(SearchDocumentResponse response, Class<T> type,
@Nullable Pageable pageable) {
List<SearchHit<T>> results = response.getSearchDocuments().stream() //
.map(searchDocument -> read(type, searchDocument)) //
.collect(Collectors.toList());
return new AggregatedPageImpl<>(results, pageable, response);
}
@Override @Override
public <T> SearchHits<T> read(Class<T> type, SearchDocumentResponse searchDocumentResponse) { public <T> SearchHits<T> read(Class<T> type, SearchDocumentResponse searchDocumentResponse) {
return readResponse(type, searchDocumentResponse);
Assert.notNull(type, "type must not be null");
Assert.notNull(searchDocumentResponse, "searchDocumentResponse must not be null");
long totalHits = searchDocumentResponse.getTotalHits();
float maxScore = searchDocumentResponse.getMaxScore();
String scrollId = searchDocumentResponse.getScrollId();
List<SearchHit<T>> searchHits = searchDocumentResponse.getSearchDocuments().stream() //
.map(searchDocument -> read(type, searchDocument)) //
.collect(Collectors.toList());
Aggregations aggregations = searchDocumentResponse.getAggregations();
SearchHits.TotalHitsRelation totalHitsRelation = SearchHits.TotalHitsRelation
.valueOf(searchDocumentResponse.getTotalHitsRelation());
return new SearchHits<>(totalHits, totalHitsRelation, maxScore, scrollId, searchHits, aggregations);
} }
@Override @Override
@ -191,6 +166,29 @@ public class MappingElasticsearchConverter
return new SearchHit<T>(id, score, sortValues, highlightFields, content); return new SearchHit<T>(id, score, sortValues, highlightFields, content);
} }
@Override
public <T> SearchScrollHits<T> readScroll(Class<T> type, SearchDocumentResponse searchDocumentResponse) {
return readResponse(type, searchDocumentResponse);
}
private <T> SearchHitsImpl<T> readResponse(Class<T> type, SearchDocumentResponse searchDocumentResponse) {
Assert.notNull(type, "type must not be null");
Assert.notNull(searchDocumentResponse, "searchDocumentResponse must not be null");
long totalHits = searchDocumentResponse.getTotalHits();
float maxScore = searchDocumentResponse.getMaxScore();
String scrollId = searchDocumentResponse.getScrollId();
List<SearchHit<T>> searchHits = searchDocumentResponse.getSearchDocuments().stream() //
.map(searchDocument -> read(type, searchDocument)) //
.collect(Collectors.toList());
Aggregations aggregations = searchDocumentResponse.getAggregations();
TotalHitsRelation totalHitsRelation = TotalHitsRelation
.valueOf(searchDocumentResponse.getTotalHitsRelation());
return new SearchHitsImpl<>(totalHits, totalHitsRelation, maxScore, scrollId, searchHits, aggregations);
}
@Nullable @Nullable
private Map<String, List<String>> getHighlightsAndRemapFieldNames(Class<?> type, SearchDocument searchDocument) { private Map<String, List<String>> getHighlightsAndRemapFieldNames(Class<?> type, SearchDocument searchDocument) {
Map<String, List<String>> highlightFields = searchDocument.getHighlightFields(); Map<String, List<String>> highlightFields = searchDocument.getHighlightFields();

View File

@ -28,17 +28,12 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.lang.Double;
import java.lang.Integer;
import java.lang.Long;
import java.lang.Object;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -290,7 +285,7 @@ public abstract class ElasticsearchTemplateTests {
// then // then
assertThat(searchHits).isNotNull(); assertThat(searchHits).isNotNull();
assertThat(searchHits.getTotalHits()).isEqualTo(1); assertThat(searchHits.getTotalHits()).isEqualTo(1);
assertThat(searchHits.getTotalHitsRelation()).isEqualByComparingTo(SearchHits.TotalHitsRelation.EQUAL_TO); assertThat(searchHits.getTotalHitsRelation()).isEqualByComparingTo(TotalHitsRelation.EQUAL_TO);
} }
@Test // DATAES-595 @Test // DATAES-595
@ -1055,11 +1050,11 @@ public abstract class ElasticsearchTemplateTests {
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria());
criteriaQuery.setPageable(PageRequest.of(0, 10)); criteriaQuery.setPageable(PageRequest.of(0, 10));
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
criteriaQuery, SampleEntity.class, index); criteriaQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }
@ -1082,11 +1077,11 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery())
.withPageable(PageRequest.of(0, 10)).build(); .withPageable(PageRequest.of(0, 10)).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }
@ -1109,12 +1104,12 @@ public abstract class ElasticsearchTemplateTests {
criteriaQuery.addFields("message"); criteriaQuery.addFields("message");
criteriaQuery.setPageable(PageRequest.of(0, 10)); criteriaQuery.setPageable(PageRequest.of(0, 10));
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
criteriaQuery, SampleEntity.class, index); criteriaQuery, SampleEntity.class, index);
String scrollId = scroll.getScrollId(); String scrollId = scroll.getScrollId();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scrollId = scroll.getScrollId(); scrollId = scroll.getScrollId();
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class);
} }
@ -1136,12 +1131,12 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).withFields("message") NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).withFields("message")
.withQuery(matchAllQuery()).withPageable(PageRequest.of(0, 10)).build(); .withQuery(matchAllQuery()).withPageable(PageRequest.of(0, 10)).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
String scrollId = scroll.getScrollId(); String scrollId = scroll.getScrollId();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scrollId = scroll.getScrollId(); scrollId = scroll.getScrollId();
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class);
} }
@ -1163,12 +1158,12 @@ public abstract class ElasticsearchTemplateTests {
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria());
criteriaQuery.setPageable(PageRequest.of(0, 10)); criteriaQuery.setPageable(PageRequest.of(0, 10));
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
criteriaQuery, SampleEntity.class, index); criteriaQuery, SampleEntity.class, index);
String scrollId = scroll.getScrollId(); String scrollId = scroll.getScrollId();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scrollId = scroll.getScrollId(); scrollId = scroll.getScrollId();
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class);
} }
@ -1190,12 +1185,12 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery())
.withPageable(PageRequest.of(0, 10)).build(); .withPageable(PageRequest.of(0, 10)).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
String scrollId = scroll.getScrollId(); String scrollId = scroll.getScrollId();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scrollId = scroll.getScrollId(); scrollId = scroll.getScrollId();
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class);
} }
@ -1217,12 +1212,12 @@ public abstract class ElasticsearchTemplateTests {
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria());
criteriaQuery.setPageable(PageRequest.of(0, 10)); criteriaQuery.setPageable(PageRequest.of(0, 10));
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
criteriaQuery, SampleEntity.class, index); criteriaQuery, SampleEntity.class, index);
String scrollId = scroll.getScrollId(); String scrollId = scroll.getScrollId();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scrollId = scroll.getScrollId(); scrollId = scroll.getScrollId();
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class);
} }
@ -1244,12 +1239,12 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery())
.withPageable(PageRequest.of(0, 10)).build(); .withPageable(PageRequest.of(0, 10)).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
String scrollId = scroll.getScrollId(); String scrollId = scroll.getScrollId();
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scrollId = scroll.getScrollId(); scrollId = scroll.getScrollId();
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class);
} }
@ -1529,16 +1524,16 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery())
.withIndicesOptions(IndicesOptions.lenientExpandOpen()).build(); .withIndicesOptions(IndicesOptions.lenientExpandOpen()).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations) SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations)
.searchScrollStart(scrollTimeInMillis, searchQuery, SampleEntity.class, index); .searchScrollStart(scrollTimeInMillis, searchQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> entities = new ArrayList<>(scroll.getContent()); List<SearchHit<SampleEntity>> entities = new ArrayList<>(scroll.getSearchHits());
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(),
scrollTimeInMillis, SampleEntity.class); scrollTimeInMillis, SampleEntity.class);
entities.addAll(scroll.getContent()); entities.addAll(scroll.getSearchHits());
} }
// then // then
@ -2431,11 +2426,11 @@ public abstract class ElasticsearchTemplateTests {
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria("message").contains("message")); CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria("message").contains("message"));
criteriaQuery.setPageable(PageRequest.of(0, 10)); criteriaQuery.setPageable(PageRequest.of(0, 10));
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
criteriaQuery, SampleEntity.class, index); criteriaQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }
@ -2469,11 +2464,11 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchQuery("message", "message")) NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchQuery("message", "message"))
.withPageable(PageRequest.of(0, 10)).build(); .withPageable(PageRequest.of(0, 10)).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }
@ -2502,11 +2497,11 @@ public abstract class ElasticsearchTemplateTests {
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery())
.withPageable(PageRequest.of(0, 10)).withSourceFilter(sourceFilter).build(); .withPageable(PageRequest.of(0, 10)).withSourceFilter(sourceFilter).build();
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }
@ -2549,11 +2544,11 @@ public abstract class ElasticsearchTemplateTests {
.withSort(new FieldSortBuilder("message").order(SortOrder.DESC)).withPageable(PageRequest.of(0, 10)).build(); .withSort(new FieldSortBuilder("message").order(SortOrder.DESC)).withPageable(PageRequest.of(0, 10)).build();
// when // when
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }
@ -2598,11 +2593,11 @@ public abstract class ElasticsearchTemplateTests {
.build(); .build();
// when // when
ScrolledPage<SearchHit<SampleEntity>> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, SearchScrollHits<SampleEntity> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000,
searchQuery, SampleEntity.class, index); searchQuery, SampleEntity.class, index);
List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>(); List<SearchHit<SampleEntity>> sampleEntities = new ArrayList<>();
while (scroll.hasContent()) { while (scroll.hasSearchHits()) {
sampleEntities.addAll(scroll.getContent()); sampleEntities.addAll(scroll.getSearchHits());
scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000,
SampleEntity.class); SampleEntity.class);
} }

View File

@ -19,12 +19,14 @@ import static org.assertj.core.api.Assertions.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import org.elasticsearch.search.aggregations.Aggregations;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageImpl;
import org.springframework.data.util.CloseableIterator; import org.springframework.data.domain.Pageable;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
/** /**
@ -36,42 +38,51 @@ public class StreamQueriesTest {
public void shouldCallClearScrollOnIteratorClose() { public void shouldCallClearScrollOnIteratorClose() {
// given // given
List<String> results = new ArrayList<>(); List<SearchHit<String>> hits = new ArrayList<>();
results.add("one"); hits.add(new SearchHit<String>(null, 0, null, null, "one"));
ScrolledPage<String> page = new ScrolledPageImpl("1234", results); SearchScrollHits<String> searchHits = newSearchScrollHits(hits);
AtomicBoolean clearScrollCalled = new AtomicBoolean(false); AtomicBoolean clearScrollCalled = new AtomicBoolean(false);
// when // when
CloseableIterator<String> closeableIterator = StreamQueries.streamResults( // SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
page, // searchHits, //
scrollId -> new ScrolledPageImpl(scrollId, Collections.emptyList()), // scrollId -> newSearchScrollHits(Collections.emptyList()), //
scrollId -> clearScrollCalled.set(true)); scrollId -> clearScrollCalled.set(true));
while (closeableIterator.hasNext()) { while (iterator.hasNext()) {
closeableIterator.next(); iterator.next();
} }
closeableIterator.close(); iterator.close();
// then // then
assertThat(clearScrollCalled).isTrue(); assertThat(clearScrollCalled).isTrue();
} }
private static class ScrolledPageImpl extends PageImpl<String> implements ScrolledPage<String> { @Test // DATAES-766
public void shouldReturnTotalHits() {
private String scrollId; // given
List<SearchHit<String>> hits = new ArrayList<>();
hits.add(new SearchHit<String>(null, 0, null, null, "one"));
SearchScrollHits<String> searchHits = newSearchScrollHits(hits);
// when
SearchHitsIterator<String> iterator = StreamQueries.streamResults( //
searchHits, //
scrollId -> newSearchScrollHits(Collections.emptyList()), //
scrollId -> {
});
// then
assertThat(iterator.getTotalHits()).isEqualTo(1);
public ScrolledPageImpl(String scrollId, List<String> content) {
super(content);
this.scrollId = scrollId;
} }
@Override private SearchScrollHits<String> newSearchScrollHits(List<SearchHit<String>> hits) {
@Nullable return new SearchHitsImpl<String>(hits.size(), TotalHitsRelation.EQUAL_TO, 0, "1234", hits, null);
public String getScrollId() {
return scrollId;
}
} }
} }