diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java index 02f8547c1..39b3b65a9 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java @@ -158,5 +158,13 @@ public final class SearchHitSupport { public SearchHits getSearchHits() { return searchHits; } + + /* + * return the same instance as in getSearchHits().getSearchHits() + */ + @Override + public List> getContent() { + return searchHits.getSearchHits(); + } } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java index e8bf45452..4d0c77aaf 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java @@ -19,11 +19,12 @@ import java.util.Collections; import java.util.List; import org.elasticsearch.search.aggregations.Aggregations; +import org.springframework.data.util.Lazy; import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** - * Basic implementation of {@link SearchScrollHits} + * Basic implementation of {@link SearchScrollHits} * * @param the result data class. * @author Peter-Josef Meisch @@ -35,9 +36,10 @@ public class SearchHitsImpl implements SearchScrollHits { private final long totalHits; private final TotalHitsRelation totalHitsRelation; private final float maxScore; - private final String scrollId; + @Nullable private final String scrollId; private final List> searchHits; - private final Aggregations aggregations; + private final Lazy>> unmodifiableSearchHits; + @Nullable private final Aggregations aggregations; /** * @param totalHits the number of total hits for the search @@ -58,6 +60,7 @@ public class SearchHitsImpl implements SearchScrollHits { this.scrollId = scrollId; this.searchHits = searchHits; this.aggregations = aggregations; + this.unmodifiableSearchHits = Lazy.of(() -> Collections.unmodifiableList(searchHits)); } // region getter @@ -84,7 +87,7 @@ public class SearchHitsImpl implements SearchScrollHits { @Override public List> getSearchHits() { - return Collections.unmodifiableList(searchHits); + return unmodifiableSearchHits.get(); } // endregion diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java index cb7d07694..9c70583c7 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java @@ -15,12 +15,15 @@ */ package org.springframework.data.elasticsearch.core; +import org.springframework.lang.Nullable; + /** * This interface is used to expose the current {@code scrollId} from the underlying scroll context. *

* Internal use only. * * @author Sascha Woo + * @author Peter-Josef Meisch * @param * @since 4.0 */ @@ -29,6 +32,7 @@ public interface SearchScrollHits extends SearchHits { /** * @return the scroll id */ + @Nullable String getScrollId(); } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/SearchHitSupportTest.java b/src/test/java/org/springframework/data/elasticsearch/core/SearchHitSupportTest.java index 9bd5bc4af..4911f1e0f 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/SearchHitSupportTest.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/SearchHitSupportTest.java @@ -19,17 +19,23 @@ import static java.util.Collections.*; import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; +import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; +import java.util.List; import org.elasticsearch.search.aggregations.Aggregations; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.springframework.data.domain.PageRequest; import org.springframework.data.util.CloseableIterator; /** * @author Roman Puchkovskiy + * @author Peter-Josef Meisch */ class SearchHitSupportTest { + @Test // DATAES-772 void unwrapsSearchHitsIteratorToCloseableIteratorOfEntities() { TestStringSearchHitsIterator searchHitsIterator = new TestStringSearchHitsIterator(); @@ -38,6 +44,7 @@ class SearchHitSupportTest { CloseableIterator unwrappedIterator = (CloseableIterator) SearchHitSupport .unwrapSearchHits(searchHitsIterator); + // noinspection ConstantConditions assertThat(unwrappedIterator.next()).isEqualTo("one"); assertThat(unwrappedIterator.next()).isEqualTo("two"); assertThat(unwrappedIterator.hasNext()).isFalse(); @@ -47,6 +54,27 @@ class SearchHitSupportTest { assertThat(searchHitsIterator.closed).isTrue(); } + @Test // DATAES-952 + @DisplayName("should return the same list instance in SearchHits and getContent") + void shouldReturnTheSameListInstanceInSearchHitsAndGetContent() { + + List> hits = new ArrayList<>(); + hits.add(new SearchHit<>(null, null, 0, null, null, "one")); + hits.add(new SearchHit<>(null, null, 0, null, null, "two")); + hits.add(new SearchHit<>(null, null, 0, null, null, "three")); + hits.add(new SearchHit<>(null, null, 0, null, null, "four")); + hits.add(new SearchHit<>(null, null, 0, null, null, "five")); + + SearchHits originalSearchHits = new SearchHitsImpl<>(hits.size(), TotalHitsRelation.EQUAL_TO, 0, "scroll", + hits, null); + + SearchPage searchPage = SearchHitSupport.searchPageFor(originalSearchHits, PageRequest.of(0, 3)); + SearchHits searchHits = searchPage.getSearchHits(); + + assertThat(searchHits).isEqualTo(originalSearchHits); + assertThat(searchHits.getSearchHits()).isSameAs(searchPage.getContent()); + } + private static class TestStringSearchHitsIterator implements SearchHitsIterator { private final Iterator iterator = Arrays.asList("one", "two").iterator(); private boolean closed = false; @@ -87,4 +115,5 @@ class SearchHitSupportTest { return new SearchHit<>("index", "id", 1.0f, new Object[0], emptyMap(), nextString); } } + }