Add pipeline aggregations to NativeSearchQuery.

Original Pull Request #1809 
Closes #1255
This commit is contained in:
Peter-Josef Meisch 2021-05-11 23:21:26 +02:00 committed by GitHub
parent 3a900599f2
commit df0d65eda2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 8 deletions

View File

@ -77,7 +77,6 @@ import org.elasticsearch.index.reindex.UpdateByQueryAction;
import org.elasticsearch.index.reindex.UpdateByQueryRequest;
import org.elasticsearch.index.reindex.UpdateByQueryRequestBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
@ -1119,9 +1118,11 @@ class RequestFactory {
}
if (!isEmpty(query.getAggregations())) {
for (AbstractAggregationBuilder<?> aggregationBuilder : query.getAggregations()) {
sourceBuilder.aggregation(aggregationBuilder);
}
query.getAggregations().forEach(sourceBuilder::aggregation);
}
if (!isEmpty(query.getPipelineAggregations())) {
query.getPipelineAggregations().forEach(sourceBuilder::aggregation);
}
}
@ -1144,9 +1145,11 @@ class RequestFactory {
}
if (!isEmpty(nativeSearchQuery.getAggregations())) {
for (AbstractAggregationBuilder<?> aggregationBuilder : nativeSearchQuery.getAggregations()) {
searchRequestBuilder.addAggregation(aggregationBuilder);
}
nativeSearchQuery.getAggregations().forEach(searchRequestBuilder::addAggregation);
}
if (!isEmpty(nativeSearchQuery.getPipelineAggregations())) {
nativeSearchQuery.getPipelineAggregations().forEach(searchRequestBuilder::addAggregation);
}
}

View File

@ -22,6 +22,7 @@ import java.util.List;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.SortBuilder;
@ -48,6 +49,7 @@ public class NativeSearchQuery extends AbstractQuery {
private final List<ScriptField> scriptFields = new ArrayList<>();
@Nullable private CollapseBuilder collapseBuilder;
@Nullable private List<AbstractAggregationBuilder<?>> aggregations;
@Nullable private List<PipelineAggregationBuilder> pipelineAggregations;
@Nullable private HighlightBuilder highlightBuilder;
@Nullable private HighlightBuilder.Field[] highlightFields;
@Nullable private List<IndexBoost> indicesBoost;
@ -143,6 +145,11 @@ public class NativeSearchQuery extends AbstractQuery {
return aggregations;
}
@Nullable
public List<PipelineAggregationBuilder> getPipelineAggregations() {
return pipelineAggregations;
}
public void addAggregation(AbstractAggregationBuilder<?> aggregationBuilder) {
if (aggregations == null) {
@ -156,6 +163,10 @@ public class NativeSearchQuery extends AbstractQuery {
this.aggregations = aggregations;
}
public void setPipelineAggregations(List<PipelineAggregationBuilder> pipelineAggregationBuilders) {
this.pipelineAggregations = pipelineAggregationBuilders;
}
@Nullable
public List<IndexBoost> getIndicesBoost() {
return indicesBoost;

View File

@ -27,6 +27,7 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.SortBuilder;
@ -55,6 +56,7 @@ public class NativeSearchQueryBuilder {
private final List<ScriptField> scriptFields = new ArrayList<>();
private final List<SortBuilder<?>> sortBuilders = new ArrayList<>();
private final List<AbstractAggregationBuilder<?>> aggregationBuilders = new ArrayList<>();
private final List<PipelineAggregationBuilder> pipelineAggregationBuilders = new ArrayList<>();
@Nullable private HighlightBuilder highlightBuilder;
@Nullable private HighlightBuilder.Field[] highlightFields;
private Pageable pageable = Pageable.unpaged();
@ -105,6 +107,14 @@ public class NativeSearchQueryBuilder {
return this;
}
/**
* @since 4.3
*/
public NativeSearchQueryBuilder addAggregation(PipelineAggregationBuilder pipelineAggregationBuilder) {
this.pipelineAggregationBuilders.add(pipelineAggregationBuilder);
return this;
}
public NativeSearchQueryBuilder withHighlightBuilder(HighlightBuilder highlightBuilder) {
this.highlightBuilder = highlightBuilder;
return this;
@ -239,6 +249,10 @@ public class NativeSearchQueryBuilder {
nativeSearchQuery.setAggregations(aggregationBuilders);
}
if (!isEmpty(pipelineAggregationBuilders)) {
nativeSearchQuery.setPipelineAggregations(pipelineAggregationBuilders);
}
if (minScore > 0) {
nativeSearchQuery.setMinScore(minScore);
}

View File

@ -18,6 +18,7 @@ package org.springframework.data.elasticsearch.core.aggregation;
import static org.assertj.core.api.Assertions.*;
import static org.elasticsearch.index.query.QueryBuilders.*;
import static org.elasticsearch.search.aggregations.AggregationBuilders.*;
import static org.elasticsearch.search.aggregations.PipelineAggregatorBuilders.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.Integer;
@ -26,9 +27,14 @@ import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.pipeline.InternalStatsBucket;
import org.elasticsearch.search.aggregations.pipeline.ParsedStatsBucket;
import org.elasticsearch.search.aggregations.pipeline.StatsBucket;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
@ -109,7 +115,7 @@ public class ElasticsearchTemplateAggregationTests {
indexOperations.delete();
}
@Test
@Test // DATAES-96
public void shouldReturnAggregatedResponseForGivenSearchQuery() {
// given
@ -130,6 +136,56 @@ public class ElasticsearchTemplateAggregationTests {
assertThat(searchHits.hasSearchHits()).isFalse();
}
@Test // #1255
@DisplayName("should work with pipeline aggregations")
void shouldWorkWithPipelineAggregations() {
IndexInitializer.init(operations.indexOps(PipelineAggsEntity.class));
operations.save( //
new PipelineAggsEntity("1-1", "one"), //
new PipelineAggsEntity("2-1", "two"), //
new PipelineAggsEntity("2-2", "two"), //
new PipelineAggsEntity("3-1", "three"), //
new PipelineAggsEntity("3-2", "three"), //
new PipelineAggsEntity("3-3", "three") //
); //
NativeSearchQuery searchQuery = new NativeSearchQueryBuilder() //
.withQuery(matchAllQuery()) //
.withSearchType(SearchType.DEFAULT) //
.addAggregation(terms("keyword_aggs").field("keyword")) //
.addAggregation(statsBucket("keyword_bucket_stats", "keyword_aggs._count")) //
.withMaxResults(0) //
.build();
SearchHits<PipelineAggsEntity> searchHits = operations.search(searchQuery, PipelineAggsEntity.class);
Aggregations aggregations = searchHits.getAggregations();
assertThat(aggregations).isNotNull();
assertThat(aggregations.asMap().get("keyword_aggs")).isNotNull();
Aggregation keyword_bucket_stats = aggregations.asMap().get("keyword_bucket_stats");
assertThat(keyword_bucket_stats).isInstanceOf(StatsBucket.class);
if (keyword_bucket_stats instanceof ParsedStatsBucket) {
// Rest client
ParsedStatsBucket statsBucket = (ParsedStatsBucket) keyword_bucket_stats;
assertThat(statsBucket.getMin()).isEqualTo(1.0);
assertThat(statsBucket.getMax()).isEqualTo(3.0);
assertThat(statsBucket.getAvg()).isEqualTo(2.0);
assertThat(statsBucket.getSum()).isEqualTo(6.0);
assertThat(statsBucket.getCount()).isEqualTo(3L);
}
if (keyword_bucket_stats instanceof InternalStatsBucket) {
// transport client
InternalStatsBucket statsBucket = (InternalStatsBucket) keyword_bucket_stats;
assertThat(statsBucket.getMin()).isEqualTo(1.0);
assertThat(statsBucket.getMax()).isEqualTo(3.0);
assertThat(statsBucket.getAvg()).isEqualTo(2.0);
assertThat(statsBucket.getSum()).isEqualTo(6.0);
assertThat(statsBucket.getCount()).isEqualTo(3L);
}
}
// region entities
@Document(indexName = "test-index-articles-core-aggregation")
static class ArticleEntity {
@ -256,4 +312,34 @@ public class ElasticsearchTemplateAggregationTests {
}
}
@Document(indexName = "pipeline-aggs")
static class PipelineAggsEntity {
@Id private String id;
@Field(type = Keyword) private String keyword;
public PipelineAggsEntity() {}
public PipelineAggsEntity(String id, String keyword) {
this.id = id;
this.keyword = keyword;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getKeyword() {
return keyword;
}
public void setKeyword(String keyword) {
this.keyword = keyword;
}
}
// endregion
}