diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java index 9ab614cbc43..a9c6901d982 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java @@ -158,11 +158,13 @@ import org.elasticsearch.search.aggregations.metrics.ParsedTDigestPercentileRank import org.elasticsearch.search.aggregations.metrics.ParsedTDigestPercentiles; import org.elasticsearch.search.aggregations.metrics.ParsedTopHits; import org.elasticsearch.search.aggregations.metrics.ParsedValueCount; +import org.elasticsearch.search.aggregations.metrics.ParsedWeightedAvg; import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.WeightedAvgAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.ExtendedStatsBucketPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.InternalBucketMetricValue; @@ -1732,6 +1734,7 @@ public class RestHighLevelClient implements Closeable { map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)); map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)); map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)); + map.put(WeightedAvgAggregationBuilder.NAME, (p, c) -> ParsedWeightedAvg.fromXContent(p, (String) c)); map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c)); map.put(InternalSimpleValue.NAME, (p, c) -> ParsedSimpleValue.fromXContent(p, (String) c)); map.put(DerivativePipelineAggregationBuilder.NAME, (p, c) -> ParsedDerivative.fromXContent(p, (String) c)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/SearchIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/SearchIT.java index 26b5b286e89..fad42d3c44c 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/SearchIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/SearchIT.java @@ -66,6 +66,9 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.matrix.stats.MatrixStats; import org.elasticsearch.search.aggregations.matrix.stats.MatrixStatsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.WeightedAvg; +import org.elasticsearch.search.aggregations.metrics.WeightedAvgAggregationBuilder; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; import org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; @@ -371,6 +374,42 @@ public class SearchIT extends ESRestHighLevelClientTestCase { } } + public void testSearchWithTermsAndWeightedAvg() throws IOException { + SearchRequest searchRequest = new SearchRequest("index"); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + TermsAggregationBuilder agg = new TermsAggregationBuilder("agg1", ValueType.STRING).field("type.keyword"); + agg.subAggregation(new WeightedAvgAggregationBuilder("subagg") + .value(new MultiValuesSourceFieldConfig.Builder().setFieldName("num").build()) + .weight(new MultiValuesSourceFieldConfig.Builder().setFieldName("num2").build()) + ); + searchSourceBuilder.aggregation(agg); + searchSourceBuilder.size(0); + searchRequest.source(searchSourceBuilder); + SearchResponse searchResponse = execute(searchRequest, highLevelClient()::search, highLevelClient()::searchAsync); + assertSearchHeader(searchResponse); + assertNull(searchResponse.getSuggest()); + assertEquals(Collections.emptyMap(), searchResponse.getProfileResults()); + assertEquals(0, searchResponse.getHits().getHits().length); + assertEquals(Float.NaN, searchResponse.getHits().getMaxScore(), 0f); + Terms termsAgg = searchResponse.getAggregations().get("agg1"); + assertEquals("agg1", termsAgg.getName()); + assertEquals(2, termsAgg.getBuckets().size()); + Terms.Bucket type1 = termsAgg.getBucketByKey("type1"); + assertEquals(3, type1.getDocCount()); + assertEquals(1, type1.getAggregations().asList().size()); + { + WeightedAvg weightedAvg = type1.getAggregations().get("subagg"); + assertEquals(24.4, weightedAvg.getValue(), 0f); + } + Terms.Bucket type2 = termsAgg.getBucketByKey("type2"); + assertEquals(2, type2.getDocCount()); + assertEquals(1, type2.getAggregations().asList().size()); + { + WeightedAvg weightedAvg = type2.getAggregations().get("subagg"); + assertEquals(100, weightedAvg.getValue(), 0f); + } + } + public void testSearchWithMatrixStats() throws IOException { SearchRequest searchRequest = new SearchRequest("index"); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ParsedWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ParsedWeightedAvg.java index 984b8509db7..bda50fb79b6 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ParsedWeightedAvg.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ParsedWeightedAvg.java @@ -25,7 +25,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; -class ParsedWeightedAvg extends ParsedSingleValueNumericMetricsAggregation implements WeightedAvg { +public class ParsedWeightedAvg extends ParsedSingleValueNumericMetricsAggregation implements WeightedAvg { @Override public double getValue() { diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/AggregationsTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/AggregationsTests.java index 097a3949fc2..ef001b35fef 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/AggregationsTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/AggregationsTests.java @@ -71,6 +71,7 @@ import org.elasticsearch.search.aggregations.metrics.InternalTDigestPercentilesT import org.elasticsearch.search.aggregations.metrics.InternalScriptedMetricTests; import org.elasticsearch.search.aggregations.metrics.InternalTopHitsTests; import org.elasticsearch.search.aggregations.metrics.InternalValueCountTests; +import org.elasticsearch.search.aggregations.metrics.InternalWeightedAvgTests; import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValueTests; import org.elasticsearch.search.aggregations.pipeline.InternalBucketMetricValueTests; import org.elasticsearch.search.aggregations.pipeline.InternalPercentilesBucketTests; @@ -114,6 +115,7 @@ public class AggregationsTests extends ESTestCase { aggsTests.add(new InternalMinTests()); aggsTests.add(new InternalMaxTests()); aggsTests.add(new InternalAvgTests()); + aggsTests.add(new InternalWeightedAvgTests()); aggsTests.add(new InternalSumTests()); aggsTests.add(new InternalValueCountTests()); aggsTests.add(new InternalSimpleValueTests()); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalWeightedAvgTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalWeightedAvgTests.java new file mode 100644 index 00000000000..b74eac80496 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalWeightedAvgTests.java @@ -0,0 +1,114 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.elasticsearch.common.io.stream.Writeable.Reader; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.ParsedAggregation; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.test.InternalAggregationTestCase; + +public class InternalWeightedAvgTests extends InternalAggregationTestCase { + + @Override + protected InternalWeightedAvg createTestInstance( + String name, + List pipelineAggregators, + Map metaData + ) { + DocValueFormat formatter = randomNumericDocValueFormat(); + return new InternalWeightedAvg( + name, + randomDoubleBetween(0, 100000, true), + randomDoubleBetween(0, 100000, true), + formatter, pipelineAggregators, metaData); + } + + @Override + protected Reader instanceReader() { + return InternalWeightedAvg::new; + } + + @Override + protected void assertReduced(InternalWeightedAvg reduced, List inputs) { + double sum = 0; + double weight = 0; + for (InternalWeightedAvg in : inputs) { + sum += in.getSum(); + weight += in.getWeight(); + } + assertEquals(sum, reduced.getSum(), 0.0000001); + assertEquals(weight, reduced.getWeight(), 0.0000001); + assertEquals(sum / weight, reduced.getValue(), 0.0000001); + } + + @Override + protected void assertFromXContent(InternalWeightedAvg avg, ParsedAggregation parsedAggregation) { + ParsedWeightedAvg parsed = ((ParsedWeightedAvg) parsedAggregation); + assertEquals(avg.getValue(), parsed.getValue(), Double.MIN_VALUE); + // we don't print out VALUE_AS_STRING for avg.getCount() == 0, so we cannot get the exact same value back + if (avg.getWeight() != 0) { + assertEquals(avg.getValueAsString(), parsed.getValueAsString()); + } + } + + @Override + protected InternalWeightedAvg mutateInstance(InternalWeightedAvg instance) { + String name = instance.getName(); + double sum = instance.getSum(); + double weight = instance.getWeight(); + DocValueFormat formatter = instance.getFormatter(); + List pipelineAggregators = instance.pipelineAggregators(); + Map metaData = instance.getMetaData(); + switch (between(0, 2)) { + case 0: + name += randomAlphaOfLength(5); + break; + case 1: + if (Double.isFinite(sum)) { + sum += between(1, 100); + } else { + sum = between(1, 100); + } + break; + case 2: + if (Double.isFinite(weight)) { + weight += between(1, 100); + } else { + weight = between(1, 100); + } + break; + case 3: + if (metaData == null) { + metaData = new HashMap<>(1); + } else { + metaData = new HashMap<>(instance.getMetaData()); + } + metaData.put(randomAlphaOfLength(15), randomInt()); + break; + default: + throw new AssertionError("Illegal randomisation branch"); + } + return new InternalWeightedAvg(name, sum, weight, formatter, pipelineAggregators, metaData); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java index 59327121c90..551110ca252 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java @@ -114,11 +114,13 @@ import org.elasticsearch.search.aggregations.metrics.ParsedTDigestPercentileRank import org.elasticsearch.search.aggregations.metrics.ParsedTDigestPercentiles; import org.elasticsearch.search.aggregations.metrics.ParsedTopHits; import org.elasticsearch.search.aggregations.metrics.ParsedValueCount; +import org.elasticsearch.search.aggregations.metrics.ParsedWeightedAvg; import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.WeightedAvgAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.ExtendedStatsBucketPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.InternalBucketMetricValue; @@ -186,6 +188,7 @@ public abstract class InternalAggregationTestCase map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)); map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)); map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)); + map.put(WeightedAvgAggregationBuilder.NAME, (p, c) -> ParsedWeightedAvg.fromXContent(p, (String) c)); map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c)); map.put(InternalSimpleValue.NAME, (p, c) -> ParsedSimpleValue.fromXContent(p, (String) c)); map.put(DerivativePipelineAggregationBuilder.NAME, (p, c) -> ParsedDerivative.fromXContent(p, (String) c));