diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchOperations.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchOperations.java index 3b27d5cc6..eece21e09 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchOperations.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchOperations.java @@ -16,7 +16,6 @@ package org.springframework.data.elasticsearch.core; import org.elasticsearch.action.update.UpdateResponse; -import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.AliasMetaData; import org.elasticsearch.common.Nullable; import org.springframework.data.domain.Page; @@ -28,6 +27,7 @@ import org.springframework.data.util.CloseableIterator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; /** * ElasticsearchOperations @@ -35,6 +35,7 @@ import java.util.Map; * @author Rizwan Idrees * @author Mohsin Husen * @author Kevin Leturc + * @author Zetang Zeng */ public interface ElasticsearchOperations { @@ -186,6 +187,42 @@ public interface ElasticsearchOperations { */ Page queryForPage(SearchQuery query, Class clazz, SearchResultMapper mapper); + /** + * Execute the multi-search against elasticsearch and return result as {@link List} of {@link Page} + * + * @param queries + * @param clazz + * @return + */ + List> queryForPage(List queries, Class clazz); + + /** + * Execute the multi-search against elasticsearch and return result as {@link List} of {@link Page} using custom mapper + * + * @param queries + * @param clazz + * @return + */ + List> queryForPage(List queries, Class clazz, SearchResultMapper mapper); + + /** + * Execute the multi-search against elasticsearch and return result as {@link List} of {@link Page} + * + * @param queries + * @param classes + * @return + */ + List> queryForPage(List queries, List> classes); + + /** + * Execute the multi-search against elasticsearch and return result as {@link List} of {@link Page} using custom mapper + * + * @param queries + * @param classes + * @return + */ + List> queryForPage(List queries, List> classes, SearchResultMapper mapper); + /** * Execute the query against elasticsearch and return result as {@link Page} * @@ -283,6 +320,29 @@ public interface ElasticsearchOperations { */ List queryForList(SearchQuery query, Class clazz); + /** + * Execute the multi search query against elasticsearch and return result as {@link List} + * + * @param queries + * @param clazz + * @param + * @return + */ + default List> queryForList(List queries, Class clazz) { + return queryForPage(queries, clazz).stream().map(Page::getContent).collect(Collectors.toList()); + } + + /** + * Execute the multi search query against elasticsearch and return result as {@link List} + * + * @param queries + * @param classes + * @return + */ + default List> queryForList(List queries, List> classes) { + return queryForPage(queries, classes).stream().map(Page::getContent).collect(Collectors.toList()); + } + /** * Execute the query against elasticsearch and return ids * diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index 68e14ede5..a53f67200 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -47,6 +47,8 @@ import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.ClearScrollResponse; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchScrollRequest; @@ -127,6 +129,7 @@ import org.springframework.util.StringUtils; * @author Sascha Woo * @author Ted Liang * @author Don Wellington + * @author Zetang Zeng */ public class ElasticsearchRestTemplate implements ElasticsearchOperations, EsClient, ApplicationContextAware { @@ -335,6 +338,68 @@ public class ElasticsearchRestTemplate return mapper.mapResults(response, clazz, query.getPageable()); } + @Override + public List> queryForPage(List queries, Class clazz) { + return queryForPage(queries, clazz, resultsMapper); + } + + private List> doMultiSearch(List queries, Class clazz, MultiSearchRequest request, SearchResultMapper resultsMapper) { + MultiSearchResponse.Item[] items = getMultiSearchResult(request); + List> res = new ArrayList<>(queries.size()); + int c = 0; + for (SearchQuery query : queries) { + res.add(resultsMapper.mapResults(items[c++].getResponse(), clazz, query.getPageable())); + } + return res; + } + + private List> doMultiSearch(List queries, List> classes, MultiSearchRequest request, SearchResultMapper resultsMapper) { + MultiSearchResponse.Item[] items = getMultiSearchResult(request); + List> res = new ArrayList<>(queries.size()); + int c = 0; + Iterator> it = classes.iterator(); + for (SearchQuery query : queries) { + res.add(resultsMapper.mapResults(items[c++].getResponse(), it.next(), query.getPageable())); + } + return res; + } + + private MultiSearchResponse.Item[] getMultiSearchResult(MultiSearchRequest request) { + MultiSearchResponse response; + try { + response = client.multiSearch(request); + } catch (IOException e) { + throw new ElasticsearchException("Error for search request: " + request.toString(), e); + } + MultiSearchResponse.Item[] items = response.getResponses(); + Assert.isTrue(items.length == request.requests().size(), "Response should has same length with queries"); + return items; + } + + @Override + public List> queryForPage(List queries, Class clazz, SearchResultMapper mapper) { + MultiSearchRequest request = new MultiSearchRequest(); + for (SearchQuery query : queries) { + request.add(prepareSearch(prepareSearch(query, clazz), query)); + } + return doMultiSearch(queries, clazz, request, mapper); + } + + @Override + public List> queryForPage(List queries, List> classes) { + return queryForPage(queries, classes, resultsMapper); + } + + @Override + public List> queryForPage(List queries, List> classes, SearchResultMapper mapper) { + MultiSearchRequest request = new MultiSearchRequest(); + Iterator> it = classes.iterator(); + for (SearchQuery query : queries) { + request.add(prepareSearch(prepareSearch(query, it.next()), query)); + } + return doMultiSearch(queries, classes, request, mapper); + } + @Override public T query(SearchQuery query, ResultsExtractor resultsExtractor) { SearchResponse response = doSearch(prepareSearch(query, Optional.ofNullable(query.getQuery())), query); @@ -1026,6 +1091,16 @@ public class ElasticsearchRestTemplate } private SearchResponse doSearch(SearchRequest searchRequest, SearchQuery searchQuery) { + prepareSearch(searchRequest, searchQuery); + + try { + return client.search(searchRequest); + } catch (IOException e) { + throw new ElasticsearchException("Error for search request with scroll: " + searchRequest.toString(), e); + } + } + + private SearchRequest prepareSearch(SearchRequest searchRequest, SearchQuery searchQuery) { if (searchQuery.getFilter() != null) { searchRequest.source().postFilter(searchQuery.getFilter()); } @@ -1074,12 +1149,7 @@ public class ElasticsearchRestTemplate searchRequest.source().aggregation(aggregatedFacet.getFacet()); } } - - try { - return client.search(searchRequest); - } catch (IOException e) { - throw new ElasticsearchException("Error for search request with scroll: " + searchRequest.toString(), e); - } + return searchRequest; } private SearchResponse getSearchResponse(ActionFuture response) { diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index b86cca416..dd4eb381c 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -50,6 +50,8 @@ import org.elasticsearch.action.get.MultiGetRequest; import org.elasticsearch.action.get.MultiGetRequestBuilder; import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.update.UpdateRequestBuilder; @@ -134,6 +136,7 @@ import org.springframework.util.StringUtils; * @author Sascha Woo * @author Ted Liang * @author Jean-Baptiste Nizet + * @author Zetang Zeng */ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient, ApplicationContextAware { @@ -314,6 +317,65 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< return mapper.mapResults(response, clazz, query.getPageable()); } + @Override + public List> queryForPage(List queries, Class clazz) { + return queryForPage(queries, clazz, resultsMapper); + } + + @Override + public List> queryForPage(List queries, Class clazz, SearchResultMapper mapper) { + MultiSearchRequest request = new MultiSearchRequest(); + for (SearchQuery query : queries) { + request.add(prepareSearch(prepareSearch(query, clazz), query)); + } + return doMultiSearch(queries, clazz, request, mapper); + } + + private List> doMultiSearch(List queries, Class clazz, MultiSearchRequest request, SearchResultMapper resultsMapper) { + MultiSearchResponse.Item[] items = getMultiSearchResult(request); + List> res = new ArrayList<>(queries.size()); + int c = 0; + for (SearchQuery query : queries) { + res.add(resultsMapper.mapResults(items[c++].getResponse(), clazz, query.getPageable())); + } + return res; + } + + private List> doMultiSearch(List queries, List> classes, MultiSearchRequest request, SearchResultMapper resultsMapper) { + MultiSearchResponse.Item[] items = getMultiSearchResult(request); + List> res = new ArrayList<>(queries.size()); + int c = 0; + Iterator> it = classes.iterator(); + for (SearchQuery query : queries) { + res.add(resultsMapper.mapResults(items[c++].getResponse(), it.next(), query.getPageable())); + } + return res; + } + + private MultiSearchResponse.Item[] getMultiSearchResult(MultiSearchRequest request) { + ActionFuture future = client.multiSearch(request); + MultiSearchResponse response = future.actionGet(); + MultiSearchResponse.Item[] items = response.getResponses(); + Assert.isTrue(items.length == request.requests().size(), "Response should have same length with queries"); + return items; + } + + @Override + public List> queryForPage(List queries, List> classes) { + return queryForPage(queries, classes, resultsMapper); + } + + @Override + public List> queryForPage(List queries, List> classes, SearchResultMapper mapper) { + Assert.isTrue(queries.size() == classes.size(), "Queries should have same length with classes"); + MultiSearchRequest request = new MultiSearchRequest(); + Iterator> it = classes.iterator(); + for (SearchQuery query : queries) { + request.add(prepareSearch(prepareSearch(query, it.next()), query)); + } + return doMultiSearch(queries, classes, request, mapper); + } + @Override public T query(SearchQuery query, ResultsExtractor resultsExtractor) { SearchResponse response = doSearch(prepareSearch(query), query); @@ -887,6 +949,11 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< } private SearchResponse doSearch(SearchRequestBuilder searchRequest, SearchQuery searchQuery) { + SearchRequestBuilder requestBuilder = prepareSearch(searchRequest, searchQuery); + return getSearchResponse(requestBuilder); + } + + private SearchRequestBuilder prepareSearch(SearchRequestBuilder searchRequest, SearchQuery searchQuery) { if (searchQuery.getFilter() != null) { searchRequest.setPostFilter(searchQuery.getFilter()); } @@ -935,7 +1002,7 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< searchRequest.addAggregation(aggregatedFacet.getFacet()); } } - return getSearchResponse(searchRequest.setQuery(searchQuery.getQuery())); + return searchRequest.setQuery(searchQuery.getQuery()); } private SearchResponse getSearchResponse(SearchRequestBuilder requestBuilder) { diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 0b38fb135..d3a1b13b5 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; import org.apache.commons.lang.StringUtils; +import org.assertj.core.util.Lists; import org.elasticsearch.action.get.MultiGetItemResponse; import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.index.IndexRequest; @@ -51,6 +52,7 @@ import org.springframework.data.elasticsearch.annotations.Document; import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage; import org.springframework.data.elasticsearch.core.aggregation.impl.AggregatedPageImpl; import org.springframework.data.elasticsearch.core.query.*; +import org.springframework.data.elasticsearch.entities.Book; import org.springframework.data.elasticsearch.entities.HetroEntity1; import org.springframework.data.elasticsearch.entities.HetroEntity2; import org.springframework.data.elasticsearch.entities.SampleEntity; @@ -76,6 +78,7 @@ import static org.springframework.data.elasticsearch.utils.IndexBuilder.*; * @author Alen Turkovic * @author Sascha Woo * @author Jean-Baptiste Nizet + * @author Zetang Zeng */ @Ignore @@ -1486,6 +1489,60 @@ public class ElasticsearchTemplateTests { }); } + @Test // DATAES-487 + public void shouldReturnSameEntityForMultiSearch() { + // given + List indexQueries = new ArrayList<>(); + + indexQueries.add(buildIndex(SampleEntity.builder().id("1").message("ab").build())); + indexQueries.add(buildIndex(SampleEntity.builder().id("2").message("bc").build())); + indexQueries.add(buildIndex(SampleEntity.builder().id("3").message("ac").build())); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class); + // when + List queries = new ArrayList<>(); + + queries.add(new NativeSearchQueryBuilder().withQuery(termQuery("message", "ab")).build()); + queries.add(new NativeSearchQueryBuilder().withQuery(termQuery("message", "bc")).build()); + queries.add(new NativeSearchQueryBuilder().withQuery(termQuery("message", "ac")).build()); + // then + List> sampleEntities = elasticsearchTemplate.queryForPage(queries, SampleEntity.class); + for (Page sampleEntity : sampleEntities) { + assertThat(sampleEntity.getTotalElements(), equalTo(1L)); + } + } + + @Test // DATAES-487 + public void shouldReturnDifferentEntityForMultiSearch() { + // given + Class clazz = Book.class; + elasticsearchTemplate.deleteIndex(clazz); + elasticsearchTemplate.createIndex(clazz); + elasticsearchTemplate.putMapping(clazz); + elasticsearchTemplate.refresh(clazz); + + List indexQueries = new ArrayList<>(); + + indexQueries.add(buildIndex(SampleEntity.builder().id("1").message("ab").build())); + indexQueries.add(buildIndex(Book.builder().id("2").description("bc").build())); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class); + elasticsearchTemplate.refresh(clazz); + // when + List queries = new ArrayList<>(); + + queries.add(new NativeSearchQueryBuilder().withQuery(termQuery("message", "ab")).build()); + queries.add(new NativeSearchQueryBuilder().withQuery(termQuery("description", "bc")).build()); + // then + List> pages = elasticsearchTemplate.queryForPage(queries, Lists.newArrayList(SampleEntity.class, clazz)); + assertThat(pages.get(0).getTotalElements(), equalTo(1L)); + assertThat(pages.get(0).getContent().get(0).getClass(), equalTo(SampleEntity.class)); + assertThat(pages.get(1).getTotalElements(), equalTo(1L)); + assertThat(pages.get(1).getContent().get(0).getClass(), equalTo(clazz)); + } + @Test public void shouldDeleteDocumentBySpecifiedTypeUsingDeleteQuery() { // given