diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index 0ea8705e1..f06c9e6c1 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -945,6 +945,11 @@ public class ElasticsearchRestTemplate if (!isEmpty(query.getFields())) { searchSourceBuilder.fetchSource(toArray(query.getFields()), null); } + + if (query.getSort() != null) { + prepareSort(query, searchSourceBuilder); + } + request.source(searchSourceBuilder); return request; } @@ -992,6 +997,12 @@ public class ElasticsearchRestTemplate } request.source().version(true); + if (!isEmpty(searchQuery.getElasticsearchSorts())) { + for (SortBuilder sort : searchQuery.getElasticsearchSorts()) { + request.source().sort(sort); + } + } + try { return client.search(request); } catch (IOException e) { @@ -1304,16 +1315,7 @@ public class ElasticsearchRestTemplate } if (query.getSort() != null) { - for (Sort.Order order : query.getSort()) { - FieldSortBuilder sort = SortBuilders.fieldSort(order.getProperty()) - .order(order.getDirection().isDescending() ? SortOrder.DESC : SortOrder.ASC); - if (order.getNullHandling() == Sort.NullHandling.NULLS_FIRST) { - sort.missing("_first"); - } else if (order.getNullHandling() == Sort.NullHandling.NULLS_LAST) { - sort.missing("_last"); - } - sourceBuilder.sort(sort); - } + prepareSort(query, sourceBuilder); } if (query.getMinScore() > 0) { @@ -1323,6 +1325,19 @@ public class ElasticsearchRestTemplate return request; } + private void prepareSort(Query query, SearchSourceBuilder sourceBuilder) { + for (Sort.Order order : query.getSort()) { + FieldSortBuilder sort = SortBuilders.fieldSort(order.getProperty()) + .order(order.getDirection().isDescending() ? SortOrder.DESC : SortOrder.ASC); + if (order.getNullHandling() == Sort.NullHandling.NULLS_FIRST) { + sort.missing("_first"); + } else if (order.getNullHandling() == Sort.NullHandling.NULLS_LAST) { + sort.missing("_last"); + } + sourceBuilder.sort(sort); + } + } + private IndexRequest prepareIndex(IndexQuery query) { try { String indexName = StringUtils.isEmpty(query.getIndexName()) diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index 27ae36ed7..f4c34713b 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -825,6 +825,11 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< if (!isEmpty(query.getFields())) { requestBuilder.setFetchSource(toArray(query.getFields()), null); } + + if (query.getSort() != null) { + prepareSort(query, requestBuilder); + } + return requestBuilder; } @@ -859,6 +864,12 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< requestBuilder.setPostFilter(searchQuery.getFilter()); } + if (!isEmpty(searchQuery.getElasticsearchSorts())) { + for (SortBuilder sort : searchQuery.getElasticsearchSorts()) { + requestBuilder.addSort(sort); + } + } + return getSearchResponse(requestBuilder.setQuery(searchQuery.getQuery())); } @@ -1110,29 +1121,7 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< } if (query.getSort() != null) { - for (Sort.Order order : query.getSort()) { - SortOrder sortOrder = order.getDirection().isDescending() ? SortOrder.DESC : SortOrder.ASC; - - if (FIELD_SCORE.equals(order.getProperty())) { - ScoreSortBuilder sort = SortBuilders // - .scoreSort() // - .order(sortOrder); - - searchRequestBuilder.addSort(sort); - } else { - FieldSortBuilder sort = SortBuilders // - .fieldSort(order.getProperty()) // - .order(sortOrder); - - if (order.getNullHandling() == Sort.NullHandling.NULLS_FIRST) { - sort.missing("_first"); - } else if (order.getNullHandling() == Sort.NullHandling.NULLS_LAST) { - sort.missing("_last"); - } - - searchRequestBuilder.addSort(sort); - } - } + prepareSort(query, searchRequestBuilder); } if (query.getMinScore() > 0) { @@ -1141,6 +1130,32 @@ public class ElasticsearchTemplate implements ElasticsearchOperations, EsClient< return searchRequestBuilder; } + private void prepareSort(Query query, SearchRequestBuilder searchRequestBuilder) { + for (Sort.Order order : query.getSort()) { + SortOrder sortOrder = order.getDirection().isDescending() ? SortOrder.DESC : SortOrder.ASC; + + if (FIELD_SCORE.equals(order.getProperty())) { + ScoreSortBuilder sort = SortBuilders // + .scoreSort() // + .order(sortOrder); + + searchRequestBuilder.addSort(sort); + } else { + FieldSortBuilder sort = SortBuilders // + .fieldSort(order.getProperty()) // + .order(sortOrder); + + if (order.getNullHandling() == Sort.NullHandling.NULLS_FIRST) { + sort.missing("_first"); + } else if (order.getNullHandling() == Sort.NullHandling.NULLS_LAST) { + sort.missing("_last"); + } + + searchRequestBuilder.addSort(sort); + } + } + } + private IndexRequestBuilder prepareIndex(IndexQuery query) { try { String indexName = StringUtils.isEmpty(query.getIndexName()) diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 4e4fdd0f1..5d73aa74e 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -2717,6 +2717,96 @@ public class ElasticsearchTemplateTests { everyItem(nullValue())); } + @Test //DATAES-457 + public void shouldSortResultsGivenSortCriteriaWithScanAndScroll() { + // given + List indexQueries = new ArrayList<>(); + // first document + String documentId = randomNumeric(5); + SampleEntity sampleEntity1 = SampleEntity.builder().id(documentId).message("abc").rate(10) + .version(System.currentTimeMillis()).build(); + + // second document + String documentId2 = randomNumeric(5); + SampleEntity sampleEntity2 = SampleEntity.builder().id(documentId2).message("xyz").rate(5) + .version(System.currentTimeMillis()).build(); + + // third document + String documentId3 = randomNumeric(5); + SampleEntity sampleEntity3 = SampleEntity.builder().id(documentId3).message("xyz").rate(10) + .version(System.currentTimeMillis()).build(); + + indexQueries = getIndexQueries(Arrays.asList(sampleEntity1, sampleEntity2, sampleEntity3)); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class); + + SearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) + .withSort(new FieldSortBuilder("rate").order(SortOrder.ASC)) + .withSort(new FieldSortBuilder("message").order(SortOrder.DESC)).withPageable(PageRequest.of(0, 10)) + .build(); + // when + ScrolledPage scroll = (ScrolledPage) elasticsearchTemplate + .startScroll(1000, searchQuery, SampleEntity.class); + List sampleEntities = new ArrayList<>(); + while (scroll.hasContent()) { + sampleEntities.addAll(scroll.getContent()); + scroll = (ScrolledPage) elasticsearchTemplate + .continueScroll(scroll.getScrollId(), 1000, SampleEntity.class); + } + // then + assertThat(sampleEntities.size(), equalTo(3)); + assertThat(sampleEntities.get(0).getRate(), is(sampleEntity2.getRate())); + assertThat(sampleEntities.get(1).getRate(), is(sampleEntity3.getRate())); + assertThat(sampleEntities.get(1).getMessage(), is(sampleEntity3.getMessage())); + assertThat(sampleEntities.get(2).getRate(), is(sampleEntity1.getRate())); + assertThat(sampleEntities.get(2).getMessage(), is(sampleEntity1.getMessage())); + } + + @Test //DATAES-457 + public void shouldSortResultsGivenSortCriteriaFromPageableWithScanAndScroll() { + // given + List indexQueries = new ArrayList<>(); + // first document + String documentId = randomNumeric(5); + SampleEntity sampleEntity1 = SampleEntity.builder().id(documentId).message("abc").rate(10) + .version(System.currentTimeMillis()).build(); + + // second document + String documentId2 = randomNumeric(5); + SampleEntity sampleEntity2 = SampleEntity.builder().id(documentId2).message("xyz").rate(5) + .version(System.currentTimeMillis()).build(); + + // third document + String documentId3 = randomNumeric(5); + SampleEntity sampleEntity3 = SampleEntity.builder().id(documentId3).message("xyz").rate(10) + .version(System.currentTimeMillis()).build(); + + indexQueries = getIndexQueries(Arrays.asList(sampleEntity1, sampleEntity2, sampleEntity3)); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class); + + SearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).withPageable( + PageRequest.of(0, 10, Sort.by(Sort.Direction.ASC, "rate").and(Sort.by(Sort.Direction.DESC, "message")))) + .build(); + // when + ScrolledPage scroll = (ScrolledPage) elasticsearchTemplate + .startScroll(1000, searchQuery, SampleEntity.class); + List sampleEntities = new ArrayList<>(); + while (scroll.hasContent()) { + sampleEntities.addAll(scroll.getContent()); + scroll = (ScrolledPage) elasticsearchTemplate + .continueScroll(scroll.getScrollId(), 1000, SampleEntity.class); + } + // then + assertThat(sampleEntities.size(), equalTo(3)); + assertThat(sampleEntities.get(0).getRate(), is(sampleEntity2.getRate())); + assertThat(sampleEntities.get(1).getRate(), is(sampleEntity3.getRate())); + assertThat(sampleEntities.get(1).getMessage(), is(sampleEntity3.getMessage())); + assertThat(sampleEntities.get(2).getRate(), is(sampleEntity1.getRate())); + assertThat(sampleEntities.get(2).getMessage(), is(sampleEntity1.getMessage())); + } private IndexQuery getIndexQuery(SampleEntity sampleEntity) { return new IndexQueryBuilder()