Add WeightedAvg metric aggregation (#31037)

Adds a new single-value metrics aggregation that computes the weighted 
average of numeric values that are extracted from the aggregated 
documents. These values can be extracted from specific numeric
fields in the documents.

When calculating a regular average, each datapoint has an equal "weight"; it
contributes equally to the final value.  In contrast, weighted averages
scale each datapoint differently.  The amount that each datapoint contributes 
to the final value is extracted from the document, or provided by a script.

As a formula, a weighted average is computed as `∑(value * weight) / ∑(weight)`.

A regular average can be thought of as a weighted average where every value has
an implicit weight of `1`.

Closes #15731
Zachary Tong 2018-07-23 18:33:15 -04:00 committed by GitHub
parent 55a2d3e0dd
commit 6ba144ae31
31 changed files with 2117 additions and 68 deletions


@ -379,9 +379,9 @@ buildRestTests.setups['exams'] = '''
refresh: true
body: |
{"index":{}}
{"grade": 100}
{"grade": 100, "weight": 2}
{"index":{}}
{"grade": 50}'''
{"grade": 50, "weight": 3}'''
buildRestTests.setups['stored_example_script'] = '''
# Simple script to load a field. Not really a good example, but a simple one.


@ -13,6 +13,8 @@ bucket aggregations (some bucket aggregations enable you to sort the returned bu
include::metrics/avg-aggregation.asciidoc[]
include::metrics/weighted-avg-aggregation.asciidoc[]
include::metrics/cardinality-aggregation.asciidoc[]
include::metrics/extendedstats-aggregation.asciidoc[]


@ -0,0 +1,202 @@
[[search-aggregations-metrics-weight-avg-aggregation]]
=== Weighted Avg Aggregation
A `single-value` metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
These values can be extracted from specific numeric fields in the documents.
When calculating a regular average, each datapoint has an equal "weight"; it contributes equally to the final value. Weighted averages,
on the other hand, weight each datapoint differently. The amount that each datapoint contributes to the final value is extracted from the
document, or provided by a script.
As a formula, a weighted average is computed as `∑(value * weight) / ∑(weight)`.
A regular average can be thought of as a weighted average where every value has an implicit weight of `1`.
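As a plain-Java sketch of that formula (for illustration only; it is not part of the aggregation's API):

[source,java]
--------------------------------------------------
static double weightedAvg(double[] values, double[] weights) {
    double weightedSum = 0;  // ∑(value * weight)
    double weightSum = 0;    // ∑(weight)
    for (int i = 0; i < values.length; i++) {
        weightedSum += values[i] * weights[i];
        weightSum += weights[i];
    }
    return weightedSum / weightSum;
}
--------------------------------------------------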
.`weighted_avg` Parameters
|===
|Parameter Name |Description |Required |Default Value
|`value` | The configuration for the field or script that provides the values |Required |
|`weight` | The configuration for the field or script that provides the weights |Required |
|`format` | The numeric response formatter |Optional |
|`value_type` | A hint about the values for pure scripts or unmapped fields |Optional |
|===
The `value` and `weight` objects have per-field specific configuration:
.`value` Parameters
|===
|Parameter Name |Description |Required |Default Value
|`field` | The field that values should be extracted from |Required |
|`missing` | A value to use if the field is missing entirely |Optional |
|`script` | A script which provides the values for the document. This is mutually exclusive with `field` |Optional |
|===
.`weight` Parameters
|===
|Parameter Name |Description |Required |Default Value
|`field` | The field that weights should be extracted from |Required |
|`missing` | A weight to use if the field is missing entirely |Optional |
|`script` | A script which provides the weights for the document. This is mutually exclusive with `field` |Optional |
|===
==== Examples
If our documents have a `"grade"` field that holds a 0-100 numeric score, and a `"weight"` field which holds an arbitrary numeric weight,
we can calculate the weighted average using:
[source,js]
--------------------------------------------------
POST /exams/_search
{
"size": 0,
"aggs" : {
"weighted_grade": {
"weighted_avg": {
"value": {
"field": "grade"
},
"weight": {
"field": "weight"
}
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[setup:exams]
Which yields a response like:
[source,js]
--------------------------------------------------
{
...
"aggregations": {
"weighted_grade": {
"value": 70.0
}
}
}
--------------------------------------------------
// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/]
While multiple values-per-field are allowed, only one weight is allowed. If the aggregation encounters
a document that has more than one weight (e.g. the weight field is a multi-valued field) it will throw an exception.
If you find yourself in this situation, you will need to specify a `script` for the weight field, and use the script
to combine the multiple weights into a single value.
This single weight will be applied independently to each value extracted from the `value` field.
This example shows how a single document with multiple values will be averaged with a single weight:
[source,js]
--------------------------------------------------
POST /exams/_doc?refresh
{
"grade": [1, 2, 3],
"weight": 2
}
POST /exams/_search
{
"size": 0,
"aggs" : {
"weighted_grade": {
"weighted_avg": {
"value": {
"field": "grade"
},
"weight": {
"field": "weight"
}
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST
The three values (`1`, `2`, and `3`) will be included as independent values, all with the weight of `2`:
[source,js]
--------------------------------------------------
{
...
"aggregations": {
"weighted_grade": {
"value": 2.0
}
}
}
--------------------------------------------------
// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/]
The aggregation returns `2.0` as the result, which matches what we would expect when calculating by hand:
`((1*2) + (2*2) + (3*2)) / (2+2+2) == 2`
==== Script
Both the value and the weight can be derived from a script, instead of a field. As a simple example, the following
will add one to the grade and weight in the document using a script:
[source,js]
--------------------------------------------------
POST /exams/_search
{
"size": 0,
"aggs" : {
"weighted_grade": {
"weighted_avg": {
"value": {
"script": "doc.grade.value + 1"
},
"weight": {
"script": "doc.weight.value + 1"
}
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[setup:exams]
==== Missing values
The `missing` parameter defines how documents that are missing a value should be treated.
The default behavior is different for `value` and `weight`:
By default, if the `value` field is missing the document is ignored and the aggregation moves on to the next document.
If the `weight` field is missing, it is assumed to have a weight of `1` (like a normal average).
Both of these defaults can be overridden with the `missing` parameter:
[source,js]
--------------------------------------------------
POST /exams/_search
{
"size": 0,
"aggs" : {
"weighted_grade": {
"weighted_avg": {
"value": {
"field": "grade",
"missing": 2
},
"weight": {
"field": "weight",
"missing": 3
}
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[setup:exams]


@ -26,7 +26,7 @@ import org.elasticsearch.search.MultiValueMode;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
@ -38,7 +38,7 @@ import java.io.IOException;
import java.util.Map;
public class MatrixStatsAggregationBuilder
extends MultiValuesSourceAggregationBuilder.LeafOnly<ValuesSource.Numeric, MatrixStatsAggregationBuilder> {
extends ArrayValuesSourceAggregationBuilder.LeafOnly<ValuesSource.Numeric, MatrixStatsAggregationBuilder> {
public static final String NAME = "matrix_stats";
private MultiValueMode multiValueMode = MultiValueMode.AVG;


@ -30,7 +30,7 @@ import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.metrics.MetricsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource.NumericMultiValuesSource;
import org.elasticsearch.search.aggregations.support.ArrayValuesSource.NumericArrayValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.internal.SearchContext;
@ -43,7 +43,7 @@ import java.util.Map;
**/
final class MatrixStatsAggregator extends MetricsAggregator {
/** Multiple ValuesSource with field names */
private final NumericMultiValuesSource valuesSources;
private final NumericArrayValuesSource valuesSources;
/** array of descriptive stats, per shard, needed to compute the correlation */
ObjectArray<RunningStats> stats;
@ -53,7 +53,7 @@ final class MatrixStatsAggregator extends MetricsAggregator {
Map<String,Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
if (valuesSources != null && !valuesSources.isEmpty()) {
this.valuesSources = new NumericMultiValuesSource(valuesSources, multiValueMode);
this.valuesSources = new NumericArrayValuesSource(valuesSources, multiValueMode);
stats = context.bigArrays().newObjectArray(1);
} else {
this.valuesSources = null;


@ -23,7 +23,7 @@ import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;
@ -33,7 +33,7 @@ import java.util.List;
import java.util.Map;
final class MatrixStatsAggregatorFactory
extends MultiValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {
extends ArrayValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {
private final MultiValueMode multiValueMode;


@ -21,14 +21,14 @@ package org.elasticsearch.search.aggregations.matrix.stats;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.MultiValueMode;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceParser.NumericValuesSourceParser;
import org.elasticsearch.search.aggregations.support.ArrayValuesSourceParser.NumericValuesSourceParser;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.aggregations.support.ValuesSourceType;
import java.io.IOException;
import java.util.Map;
import static org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder.MULTIVALUE_MODE_FIELD;
import static org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregationBuilder.MULTIVALUE_MODE_FIELD;
public class MatrixStatsParser extends NumericValuesSourceParser {


@ -28,13 +28,13 @@ import java.util.Map;
/**
* Class to encapsulate a set of ValuesSource objects labeled by field name
*/
public abstract class MultiValuesSource <VS extends ValuesSource> {
public abstract class ArrayValuesSource<VS extends ValuesSource> {
protected MultiValueMode multiValueMode;
protected String[] names;
protected VS[] values;
public static class NumericMultiValuesSource extends MultiValuesSource<ValuesSource.Numeric> {
public NumericMultiValuesSource(Map<String, ValuesSource.Numeric> valuesSources, MultiValueMode multiValueMode) {
public static class NumericArrayValuesSource extends ArrayValuesSource<ValuesSource.Numeric> {
public NumericArrayValuesSource(Map<String, ValuesSource.Numeric> valuesSources, MultiValueMode multiValueMode) {
super(valuesSources, multiValueMode);
if (valuesSources != null) {
this.values = valuesSources.values().toArray(new ValuesSource.Numeric[0]);
@ -51,8 +51,8 @@ public abstract class MultiValuesSource <VS extends ValuesSource> {
}
}
public static class BytesMultiValuesSource extends MultiValuesSource<ValuesSource.Bytes> {
public BytesMultiValuesSource(Map<String, ValuesSource.Bytes> valuesSources, MultiValueMode multiValueMode) {
public static class BytesArrayValuesSource extends ArrayValuesSource<ValuesSource.Bytes> {
public BytesArrayValuesSource(Map<String, ValuesSource.Bytes> valuesSources, MultiValueMode multiValueMode) {
super(valuesSources, multiValueMode);
this.values = valuesSources.values().toArray(new ValuesSource.Bytes[0]);
}
@ -62,14 +62,14 @@ public abstract class MultiValuesSource <VS extends ValuesSource> {
}
}
public static class GeoPointValuesSource extends MultiValuesSource<ValuesSource.GeoPoint> {
public static class GeoPointValuesSource extends ArrayValuesSource<ValuesSource.GeoPoint> {
public GeoPointValuesSource(Map<String, ValuesSource.GeoPoint> valuesSources, MultiValueMode multiValueMode) {
super(valuesSources, multiValueMode);
this.values = valuesSources.values().toArray(new ValuesSource.GeoPoint[0]);
}
}
private MultiValuesSource(Map<String, ?> valuesSources, MultiValueMode multiValueMode) {
private ArrayValuesSource(Map<String, ?> valuesSources, MultiValueMode multiValueMode) {
if (valuesSources != null) {
this.names = valuesSources.keySet().toArray(new String[0]);
}


@ -44,13 +44,13 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSource, AB extends MultiValuesSourceAggregationBuilder<VS, AB>>
extends AbstractAggregationBuilder<AB> {
public abstract class ArrayValuesSourceAggregationBuilder<VS extends ValuesSource, AB extends ArrayValuesSourceAggregationBuilder<VS, AB>>
extends AbstractAggregationBuilder<AB> {
public static final ParseField MULTIVALUE_MODE_FIELD = new ParseField("mode");
public abstract static class LeafOnly<VS extends ValuesSource, AB extends MultiValuesSourceAggregationBuilder<VS, AB>>
extends MultiValuesSourceAggregationBuilder<VS, AB> {
public abstract static class LeafOnly<VS extends ValuesSource, AB extends ArrayValuesSourceAggregationBuilder<VS, AB>>
extends ArrayValuesSourceAggregationBuilder<VS, AB> {
protected LeafOnly(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) {
super(name, valuesSourceType, targetValueType);
@ -94,7 +94,7 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
private Object missing = null;
private Map<String, Object> missingMap = Collections.emptyMap();
protected MultiValuesSourceAggregationBuilder(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) {
protected ArrayValuesSourceAggregationBuilder(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) {
super(name);
if (valuesSourceType == null) {
throw new IllegalArgumentException("[valuesSourceType] must not be null: [" + name + "]");
@ -103,7 +103,7 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
this.targetValueType = targetValueType;
}
protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder<VS, AB> clone,
protected ArrayValuesSourceAggregationBuilder(ArrayValuesSourceAggregationBuilder<VS, AB> clone,
Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
this.valuesSourceType = clone.valuesSourceType;
@ -115,7 +115,7 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
this.missing = clone.missing;
}
protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType, ValueType targetValueType)
protected ArrayValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType, ValueType targetValueType)
throws IOException {
super(in);
assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType";
@ -124,7 +124,7 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
read(in);
}
protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType) throws IOException {
protected ArrayValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType) throws IOException {
super(in);
assert serializeTargetValueType() : "Wrong read constructor called for subclass that serializes its targetValueType";
this.valuesSourceType = valuesSourceType;
@ -239,10 +239,10 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
}
@Override
protected final MultiValuesSourceAggregatorFactory<VS, ?> doBuild(SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
protected final ArrayValuesSourceAggregatorFactory<VS, ?> doBuild(SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
Map<String, ValuesSourceConfig<VS>> configs = resolveConfig(context);
MultiValuesSourceAggregatorFactory<VS, ?> factory = innerBuild(context, configs, parent, subFactoriesBuilder);
ArrayValuesSourceAggregatorFactory<VS, ?> factory = innerBuild(context, configs, parent, subFactoriesBuilder);
return factory;
}
@ -255,9 +255,10 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
return configs;
}
protected abstract MultiValuesSourceAggregatorFactory<VS, ?> innerBuild(SearchContext context,
Map<String, ValuesSourceConfig<VS>> configs, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException;
protected abstract ArrayValuesSourceAggregatorFactory<VS, ?> innerBuild(SearchContext context,
Map<String, ValuesSourceConfig<VS>> configs,
AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException;
public ValuesSourceConfig<VS> config(SearchContext context, String field, Script script) {
@ -355,14 +356,14 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
@Override
protected final int doHashCode() {
return Objects.hash(fields, format, missing, targetValueType, valueType, valuesSourceType,
innerHashCode());
innerHashCode());
}
protected abstract int innerHashCode();
@Override
protected final boolean doEquals(Object obj) {
MultiValuesSourceAggregationBuilder<?, ?> other = (MultiValuesSourceAggregationBuilder<?, ?>) obj;
ArrayValuesSourceAggregationBuilder<?, ?> other = (ArrayValuesSourceAggregationBuilder<?, ?>) obj;
if (!Objects.equals(fields, other.fields))
return false;
if (!Objects.equals(format, other.format))


@ -30,14 +30,15 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
public abstract class MultiValuesSourceAggregatorFactory<VS extends ValuesSource, AF extends MultiValuesSourceAggregatorFactory<VS, AF>>
extends AggregatorFactory<AF> {
public abstract class ArrayValuesSourceAggregatorFactory<VS extends ValuesSource, AF extends ArrayValuesSourceAggregatorFactory<VS, AF>>
extends AggregatorFactory<AF> {
protected Map<String, ValuesSourceConfig<VS>> configs;
public MultiValuesSourceAggregatorFactory(String name, Map<String, ValuesSourceConfig<VS>> configs,
SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
public ArrayValuesSourceAggregatorFactory(String name, Map<String, ValuesSourceConfig<VS>> configs,
SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactoriesBuilder, metaData);
this.configs = configs;
}
@ -63,6 +64,7 @@ public abstract class MultiValuesSourceAggregatorFactory<VS extends ValuesSource
Map<String, Object> metaData) throws IOException;
protected abstract Aggregator doCreateInternal(Map<String, VS> valuesSources, Aggregator parent, boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException;
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException;
}


@ -33,30 +33,30 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
public abstract class MultiValuesSourceParser<VS extends ValuesSource> implements Aggregator.Parser {
public abstract class ArrayValuesSourceParser<VS extends ValuesSource> implements Aggregator.Parser {
public abstract static class AnyValuesSourceParser extends MultiValuesSourceParser<ValuesSource> {
public abstract static class AnyValuesSourceParser extends ArrayValuesSourceParser<ValuesSource> {
protected AnyValuesSourceParser(boolean formattable) {
super(formattable, ValuesSourceType.ANY, null);
}
}
public abstract static class NumericValuesSourceParser extends MultiValuesSourceParser<ValuesSource.Numeric> {
public abstract static class NumericValuesSourceParser extends ArrayValuesSourceParser<ValuesSource.Numeric> {
protected NumericValuesSourceParser(boolean formattable) {
super(formattable, ValuesSourceType.NUMERIC, ValueType.NUMERIC);
}
}
public abstract static class BytesValuesSourceParser extends MultiValuesSourceParser<ValuesSource.Bytes> {
public abstract static class BytesValuesSourceParser extends ArrayValuesSourceParser<ValuesSource.Bytes> {
protected BytesValuesSourceParser(boolean formattable) {
super(formattable, ValuesSourceType.BYTES, ValueType.STRING);
}
}
public abstract static class GeoPointValuesSourceParser extends MultiValuesSourceParser<ValuesSource.GeoPoint> {
public abstract static class GeoPointValuesSourceParser extends ArrayValuesSourceParser<ValuesSource.GeoPoint> {
protected GeoPointValuesSourceParser(boolean formattable) {
super(formattable, ValuesSourceType.GEOPOINT, ValueType.GEOPOINT);
@ -67,15 +67,15 @@ public abstract class MultiValuesSourceParser<VS extends ValuesSource> implement
private ValuesSourceType valuesSourceType = null;
private ValueType targetValueType = null;
private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSourceType, ValueType targetValueType) {
private ArrayValuesSourceParser(boolean formattable, ValuesSourceType valuesSourceType, ValueType targetValueType) {
this.valuesSourceType = valuesSourceType;
this.targetValueType = targetValueType;
this.formattable = formattable;
}
@Override
public final MultiValuesSourceAggregationBuilder<VS, ?> parse(String aggregationName, XContentParser parser)
throws IOException {
public final ArrayValuesSourceAggregationBuilder<VS, ?> parse(String aggregationName, XContentParser parser)
throws IOException {
List<String> fields = null;
ValueType valueType = null;
@ -98,7 +98,7 @@ public abstract class MultiValuesSourceParser<VS extends ValuesSource> implement
"Multi-field aggregations do not support scripts.");
} else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) {
throw new ParsingException(parser.getTokenLocation(),
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
}
} else if (token == XContentParser.Token.START_OBJECT) {
if (CommonFields.MISSING.match(currentFieldName, parser.getDeprecationHandler())) {
@ -113,7 +113,7 @@ public abstract class MultiValuesSourceParser<VS extends ValuesSource> implement
} else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) {
throw new ParsingException(parser.getTokenLocation(),
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
}
} else if (token == XContentParser.Token.START_ARRAY) {
if (Script.SCRIPT_PARSE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
@ -127,21 +127,21 @@ public abstract class MultiValuesSourceParser<VS extends ValuesSource> implement
fields.add(parser.text());
} else {
throw new ParsingException(parser.getTokenLocation(),
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
}
}
} else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) {
throw new ParsingException(parser.getTokenLocation(),
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
}
} else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) {
throw new ParsingException(parser.getTokenLocation(),
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
"Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "].");
}
}
MultiValuesSourceAggregationBuilder<VS, ?> factory = createFactory(aggregationName, this.valuesSourceType, this.targetValueType,
otherOptions);
ArrayValuesSourceAggregationBuilder<VS, ?> factory = createFactory(aggregationName, this.valuesSourceType, this.targetValueType,
otherOptions);
if (fields != null) {
factory.fields(fields);
}
@ -182,7 +182,7 @@ public abstract class MultiValuesSourceParser<VS extends ValuesSource> implement
/**
* Creates a {@link ValuesSourceAggregationBuilder} from the information
* gathered by the subclass. Options parsed in
* {@link MultiValuesSourceParser} itself will be added to the factory
* {@link ArrayValuesSourceParser} itself will be added to the factory
* after it has been returned by this method.
*
* @param aggregationName
@ -197,11 +197,13 @@ public abstract class MultiValuesSourceParser<VS extends ValuesSource> implement
* method
* @return the created factory
*/
protected abstract MultiValuesSourceAggregationBuilder<VS, ?> createFactory(String aggregationName, ValuesSourceType valuesSourceType,
ValueType targetValueType, Map<ParseField, Object> otherOptions);
protected abstract ArrayValuesSourceAggregationBuilder<VS, ?> createFactory(String aggregationName,
ValuesSourceType valuesSourceType,
ValueType targetValueType,
Map<ParseField, Object> otherOptions);
/**
* Allows subclasses of {@link MultiValuesSourceParser} to parse extra
* Allows subclasses of {@link ArrayValuesSourceParser} to parse extra
* parameters and store them in a {@link Map} which will later be passed to
* {@link #createFactory(String, ValuesSourceType, ValueType, Map)}.
*


@ -0,0 +1,71 @@
setup:
- do:
indices.create:
index: test_1
body:
settings:
number_of_replicas: 0
mappings:
doc:
properties:
int_field:
type : integer
double_field:
type : double
string_field:
type: keyword
- do:
bulk:
refresh: true
body:
- index:
_index: test_1
_type: doc
_id: 1
- int_field: 1
double_field: 1.0
- index:
_index: test_1
_type: doc
_id: 2
- int_field: 2
double_field: 2.0
- index:
_index: test_1
_type: doc
_id: 3
- int_field: 3
double_field: 3.0
- index:
_index: test_1
_type: doc
_id: 4
- int_field: 4
double_field: 4.0
---
"Basic test":
- do:
search:
body:
aggs:
the_int_avg:
weighted_avg:
value:
field: "int_field"
weight:
field: "int_field"
the_double_avg:
weighted_avg:
value:
field: "double_field"
weight:
field: "double_field"
- match: { hits.total: 4 }
- length: { hits.hits: 4 }
- match: { aggregations.the_int_avg.value: 3.0 }
- match: { aggregations.the_double_avg.value: 3.0 }
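# Sanity check of the expected values: both fields are weighted by themselves, so
# (1*1 + 2*2 + 3*3 + 4*4) / (1 + 2 + 3 + 4) = 30 / 10 = 3.0 for both the int and double variants.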


@ -181,6 +181,8 @@ import org.elasticsearch.search.aggregations.metrics.tophits.InternalTopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.valuecount.InternalValueCount;
import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCountAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.weighted_avg.InternalWeightedAvg;
import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValue;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.InternalBucketMetricValue;
@ -335,6 +337,8 @@ public class SearchModule {
private void registerAggregations(List<SearchPlugin> plugins) {
registerAggregation(new AggregationSpec(AvgAggregationBuilder.NAME, AvgAggregationBuilder::new, AvgAggregationBuilder::parse)
.addResultReader(InternalAvg::new));
registerAggregation(new AggregationSpec(WeightedAvgAggregationBuilder.NAME, WeightedAvgAggregationBuilder::new,
WeightedAvgAggregationBuilder::parse).addResultReader(InternalWeightedAvg::new));
registerAggregation(new AggregationSpec(SumAggregationBuilder.NAME, SumAggregationBuilder::new, SumAggregationBuilder::parse)
.addResultReader(InternalSum::new));
registerAggregation(new AggregationSpec(MinAggregationBuilder.NAME, MinAggregationBuilder::new, MinAggregationBuilder::parse)


@ -82,6 +82,7 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCount;
import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCountAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder;
import java.util.Map;
@ -107,6 +108,13 @@ public class AggregationBuilders {
return new AvgAggregationBuilder(name);
}
/**
* Create a new {@link WeightedAvg} aggregation with the given name.
*/
public static WeightedAvgAggregationBuilder weightedAvg(String name) {
return new WeightedAvgAggregationBuilder(name);
}
/**
* Create a new {@link Max} aggregation with the given name.
*/


@ -0,0 +1,144 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class InternalWeightedAvg extends InternalNumericMetricsAggregation.SingleValue implements WeightedAvg {
private final double sum;
private final double weight;
public InternalWeightedAvg(String name, double sum, double weight, DocValueFormat format, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
this.sum = sum;
this.weight = weight;
this.format = format;
}
/**
* Read from a stream.
*/
public InternalWeightedAvg(StreamInput in) throws IOException {
super(in);
format = in.readNamedWriteable(DocValueFormat.class);
sum = in.readDouble();
weight = in.readDouble();
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(format);
out.writeDouble(sum);
out.writeDouble(weight);
}
@Override
public double value() {
return getValue();
}
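// `sum` already holds Σ(value * weight) and `weight` holds Σ(weight), so this ratio is the weighted average.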
@Override
public double getValue() {
return sum / weight;
}
double getSum() {
return sum;
}
double getWeight() {
return weight;
}
DocValueFormat getFormatter() {
return format;
}
@Override
public String getWriteableName() {
return WeightedAvgAggregationBuilder.NAME;
}
@Override
public InternalWeightedAvg doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
double weight = 0;
double sum = 0;
double sumCompensation = 0;
double weightCompensation = 0;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
for (InternalAggregation aggregation : aggregations) {
InternalWeightedAvg avg = (InternalWeightedAvg) aggregation;
// If the weight is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
if (Double.isFinite(avg.weight) == false) {
weight += avg.weight;
} else if (Double.isFinite(weight)) {
double corrected = avg.weight - weightCompensation;
double newWeight = weight + corrected;
weightCompensation = (newWeight - weight) - corrected;
weight = newWeight;
}
// If the avg is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
if (Double.isFinite(avg.sum) == false) {
sum += avg.sum;
} else if (Double.isFinite(sum)) {
double corrected = avg.sum - sumCompensation;
double newSum = sum + corrected;
sumCompensation = (newSum - sum) - corrected;
sum = newSum;
}
}
return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData());
}
@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? getValue() : null);
if (weight != 0 && format != DocValueFormat.RAW) {
builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), format.format(getValue()));
}
return builder;
}
@Override
protected int doHashCode() {
return Objects.hash(sum, weight, format.getWriteableName());
}
@Override
protected boolean doEquals(Object obj) {
InternalWeightedAvg other = (InternalWeightedAvg) obj;
return Objects.equals(sum, other.sum) &&
Objects.equals(weight, other.weight) &&
Objects.equals(format.getWriteableName(), other.format.getWriteableName());
}
}
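The reduce path above uses Kahan (compensated) summation rather than plain addition. A minimal standalone sketch of the technique, independent of any Elasticsearch classes, shows why the compensation term matters:

class KahanSummationDemo {
    public static void main(String[] args) {
        double naive = 0.0;
        double sum = 0.0;
        double compensation = 0.0;
        for (int i = 0; i < 10_000_000; i++) {
            double value = 0.1;
            naive += value;                           // plain accumulation drops low-order bits
            double corrected = value - compensation;  // same steps as doReduce() above
            double newSum = sum + corrected;
            compensation = (newSum - sum) - corrected;
            sum = newSum;
        }
        System.out.println("naive: " + naive); // drifts noticeably away from 1,000,000
        System.out.println("kahan: " + sum);   // much closer to 1,000,000
    }
}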


@ -0,0 +1,65 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.metrics.ParsedSingleValueNumericMetricsAggregation;
import java.io.IOException;
public class ParsedWeightedAvg extends ParsedSingleValueNumericMetricsAggregation implements WeightedAvg {
@Override
public double getValue() {
return value();
}
@Override
public String getType() {
return WeightedAvgAggregationBuilder.NAME;
}
@Override
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
// InternalWeightedAvg renders value only if the avg normalizer (count) is not 0.
// We parse back `null` as Double.POSITIVE_INFINITY so we check for that value here to get the same xContent output
boolean hasValue = value != Double.POSITIVE_INFINITY;
builder.field(CommonFields.VALUE.getPreferredName(), hasValue ? value : null);
if (hasValue && valueAsString != null) {
builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), valueAsString);
}
return builder;
}
private static final ObjectParser<ParsedWeightedAvg, Void> PARSER
= new ObjectParser<>(ParsedWeightedAvg.class.getSimpleName(), true, ParsedWeightedAvg::new);
static {
declareSingleValueFields(PARSER, Double.POSITIVE_INFINITY);
}
public static ParsedWeightedAvg fromXContent(XContentParser parser, final String name) {
ParsedWeightedAvg avg = PARSER.apply(parser, null);
avg.setName(name);
return avg;
}
}


@ -0,0 +1,32 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
/**
* An aggregation that computes the weighted average of the values in the current bucket.
*/
public interface WeightedAvg extends NumericMetricsAggregation.SingleValue {
/**
* The weighted average value.
*/
double getValue();
}


@ -0,0 +1,128 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationBuilder.LeafOnly<Numeric, WeightedAvgAggregationBuilder> {
public static final String NAME = "weighted_avg";
public static final ParseField VALUE_FIELD = new ParseField("value");
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
private static final ObjectParser<WeightedAvgAggregationBuilder, Void> PARSER;
static {
PARSER = new ObjectParser<>(WeightedAvgAggregationBuilder.NAME);
MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false);
MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false);
}
public static AggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException {
return PARSER.parse(parser, new WeightedAvgAggregationBuilder(aggregationName), null);
}
public WeightedAvgAggregationBuilder(String name) {
super(name, ValueType.NUMERIC);
}
public WeightedAvgAggregationBuilder(WeightedAvgAggregationBuilder clone, Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
}
public WeightedAvgAggregationBuilder value(MultiValuesSourceFieldConfig valueConfig) {
valueConfig = Objects.requireNonNull(valueConfig, "Configuration for field [" + VALUE_FIELD + "] cannot be null");
field(VALUE_FIELD.getPreferredName(), valueConfig);
return this;
}
public WeightedAvgAggregationBuilder weight(MultiValuesSourceFieldConfig weightConfig) {
weightConfig = Objects.requireNonNull(weightConfig, "Configuration for field [" + WEIGHT_FIELD + "] cannot be null");
field(WEIGHT_FIELD.getPreferredName(), weightConfig);
return this;
}
/**
* Read from a stream.
*/
public WeightedAvgAggregationBuilder(StreamInput in) throws IOException {
super(in, ValueType.NUMERIC);
}
@Override
protected AggregationBuilder shallowCopy(Builder factoriesBuilder, Map<String, Object> metaData) {
return new WeightedAvgAggregationBuilder(this, factoriesBuilder, metaData);
}
@Override
protected void innerWriteTo(StreamOutput out) {
// Do nothing, no extra state to write to stream
}
@Override
protected MultiValuesSourceAggregatorFactory<Numeric, ?> innerBuild(SearchContext context,
Map<String, ValuesSourceConfig<Numeric>> configs,
DocValueFormat format,
AggregatorFactory<?> parent,
Builder subFactoriesBuilder) throws IOException {
return new WeightedAvgAggregatorFactory(name, configs, format, context, parent, subFactoriesBuilder, metaData);
}
@Override
public XContentBuilder doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException {
return builder;
}
@Override
protected int innerHashCode() {
return 0;
}
@Override
protected boolean innerEquals(Object obj) {
return true;
}
@Override
public String getType() {
return NAME;
}
}
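For reference, a sketch of how a caller might wire this builder up through the Java API. `MultiValuesSourceFieldConfig` is added elsewhere in this change, and the `Builder#setFieldName` accessor used below is an assumption rather than something shown in this diff:

import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;

class WeightedAvgUsageSketch {
    // Mirrors the REST example: values from "grade", weights from "weight" (field names are illustrative).
    static WeightedAvgAggregationBuilder weightedGrade() {
        MultiValuesSourceFieldConfig value = new MultiValuesSourceFieldConfig.Builder()
            .setFieldName("grade").build();   // assumed Builder API
        MultiValuesSourceFieldConfig weight = new MultiValuesSourceFieldConfig.Builder()
            .setFieldName("weight").build();
        return AggregationBuilders.weightedAvg("weighted_grade")
            .value(value)
            .weight(weight);
    }
}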


@ -0,0 +1,158 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder.VALUE_FIELD;
import static org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder.WEIGHT_FIELD;
public class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
private final MultiValuesSource.NumericMultiValuesSource valuesSources;
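// Per-bucket running state: `sums` accumulates Σ(value * weight) and `weights` accumulates Σ(weight),
// each with a Kahan compensation array alongside it.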
private DoubleArray weights;
private DoubleArray sums;
private DoubleArray sumCompensations;
private DoubleArray weightCompensations;
private DocValueFormat format;
public WeightedAvgAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, DocValueFormat format,
SearchContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
this.valuesSources = valuesSources;
this.format = format;
if (valuesSources != null) {
final BigArrays bigArrays = context.bigArrays();
weights = bigArrays.newDoubleArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
sumCompensations = bigArrays.newDoubleArray(1, true);
weightCompensations = bigArrays.newDoubleArray(1, true);
}
}
@Override
public boolean needsScores() {
return valuesSources != null && valuesSources.needsScores();
}
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final LeafBucketCollector sub) throws IOException {
if (valuesSources == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx);
return new LeafBucketCollectorBase(sub, docValues) {
@Override
public void collect(int doc, long bucket) throws IOException {
weights = bigArrays.grow(weights, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
sumCompensations = bigArrays.grow(sumCompensations, bucket + 1);
weightCompensations = bigArrays.grow(weightCompensations, bucket + 1);
if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) {
if (docWeights.docValueCount() > 1) {
throw new AggregationExecutionException("Encountered more than one weight for a " +
"single document. Use a script to combine multiple weights-per-doc into a single value.");
}
// There should always be one weight if advanceExact lands us here, either
// a real weight or a `missing` weight
assert docWeights.docValueCount() == 1;
final double weight = docWeights.nextValue();
final int numValues = docValues.docValueCount();
assert numValues > 0;
for (int i = 0; i < numValues; i++) {
kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket);
kahanSum(weight, weights, weightCompensations, bucket);
}
}
}
};
}
private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = values.get(bucket);
double compensation = compensations.get(bucket);
if (Double.isFinite(value) == false) {
sum += value;
} else if (Double.isFinite(sum)) {
double corrected = value - compensation;
double newSum = sum + corrected;
compensation = (newSum - sum) - corrected;
sum = newSum;
}
values.set(bucket, sum);
compensations.set(bucket, compensation);
}
@Override
public double metric(long owningBucketOrd) {
if (valuesSources == null || owningBucketOrd >= sums.size()) {
return Double.NaN;
}
return sums.get(owningBucketOrd) / weights.get(owningBucketOrd);
}
@Override
public InternalAggregation buildAggregation(long bucket) {
if (valuesSources == null || bucket >= sums.size()) {
return buildEmptyAggregation();
}
return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
}
@Override
public InternalAggregation buildEmptyAggregation() {
return new InternalWeightedAvg(name, 0.0, 0L, format, pipelineAggregators(), metaData());
}
@Override
public void doClose() {
Releasables.close(weights, sums, sumCompensations, weightCompensations);
}
}


@ -0,0 +1,64 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
public class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory<Numeric, WeightedAvgAggregatorFactory> {
public WeightedAvgAggregatorFactory(String name, Map<String, ValuesSourceConfig<Numeric>> configs,
DocValueFormat format, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, configs, format, context, parent, subFactoriesBuilder, metaData);
}
@Override
protected Aggregator createUnmapped(Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData)
throws IOException {
return new WeightedAvgAggregator(name, null, format, context, parent, pipelineAggregators, metaData);
}
@Override
protected Aggregator doCreateInternal(Map<String, ValuesSourceConfig<Numeric>> configs, DocValueFormat format,
Aggregator parent, boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
MultiValuesSource.NumericMultiValuesSource numericMultiVS
= new MultiValuesSource.NumericMultiValuesSource(configs, context.getQueryShardContext());
if (numericMultiVS.areValuesSourcesEmpty()) {
return createUnmapped(parent, pipelineAggregators, metaData);
}
return new WeightedAvgAggregator(name, numericMultiVS, format, context, parent, pipelineAggregators, metaData);
}
}


@ -0,0 +1,93 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.support;
import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.index.query.QueryShardContext;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* Class to encapsulate a set of ValuesSource objects labeled by field name
*/
public abstract class MultiValuesSource <VS extends ValuesSource> {
protected Map<String, VS> values;
public static class NumericMultiValuesSource extends MultiValuesSource<ValuesSource.Numeric> {
public NumericMultiValuesSource(Map<String, ValuesSourceConfig<ValuesSource.Numeric>> valuesSourceConfigs,
QueryShardContext context) throws IOException {
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource.Numeric>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}
public SortedNumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException {
ValuesSource.Numeric value = values.get(fieldName);
if (value == null) {
throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource");
}
return value.doubleValues(ctx);
}
}
public static class BytesMultiValuesSource extends MultiValuesSource<ValuesSource.Bytes> {
public BytesMultiValuesSource(Map<String, ValuesSourceConfig<ValuesSource.Bytes>> valuesSourceConfigs,
QueryShardContext context) throws IOException {
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource.Bytes>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}
public Object getField(String fieldName, LeafReaderContext ctx) throws IOException {
ValuesSource.Bytes value = values.get(fieldName);
if (value == null) {
throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource");
}
return value.bytesValues(ctx);
}
}
public static class GeoPointValuesSource extends MultiValuesSource<ValuesSource.GeoPoint> {
public GeoPointValuesSource(Map<String, ValuesSourceConfig<ValuesSource.GeoPoint>> valuesSourceConfigs,
QueryShardContext context) throws IOException {
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource.GeoPoint>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}
}
public boolean needsScores() {
return values.values().stream().anyMatch(ValuesSource::needsScores);
}
public String[] fieldNames() {
return values.keySet().toArray(new String[0]);
}
public boolean areValuesSourcesEmpty() {
return values.values().stream().allMatch(Objects::isNull);
}
}


@ -0,0 +1,268 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.support;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationInitializationException;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* Similar to {@link ValuesSourceAggregationBuilder}, except it references multiple ValuesSources (e.g. so that an aggregation
* can pull values from multiple fields).
*
* A limitation of this class is that all the ValuesSources being referenced must be of the same type.
*/
public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSource, AB extends MultiValuesSourceAggregationBuilder<VS, AB>>
extends AbstractAggregationBuilder<AB> {
public abstract static class LeafOnly<VS extends ValuesSource, AB extends MultiValuesSourceAggregationBuilder<VS, AB>>
extends MultiValuesSourceAggregationBuilder<VS, AB> {
protected LeafOnly(String name, ValueType targetValueType) {
super(name, targetValueType);
}
protected LeafOnly(LeafOnly<VS, AB> clone, Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
if (factoriesBuilder.count() > 0) {
throw new AggregationInitializationException("Aggregator [" + name + "] of type ["
+ getType() + "] cannot accept sub-aggregations");
}
}
/**
* Read from a stream that does not serialize its targetValueType. This should be used by most subclasses.
*/
protected LeafOnly(StreamInput in, ValueType targetValueType) throws IOException {
super(in, targetValueType);
}
@Override
public AB subAggregations(Builder subFactories) {
throw new AggregationInitializationException("Aggregator [" + name + "] of type [" +
getType() + "] cannot accept sub-aggregations");
}
}
private Map<String, MultiValuesSourceFieldConfig> fields = new HashMap<>();
private final ValueType targetValueType;
private ValueType valueType = null;
private String format = null;
protected MultiValuesSourceAggregationBuilder(String name, ValueType targetValueType) {
super(name);
this.targetValueType = targetValueType;
}
protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder<VS, AB> clone,
Builder factoriesBuilder, Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
this.fields = new HashMap<>(clone.fields);
this.targetValueType = clone.targetValueType;
this.valueType = clone.valueType;
this.format = clone.format;
}
protected MultiValuesSourceAggregationBuilder(StreamInput in, ValueType targetValueType)
throws IOException {
super(in);
assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType";
this.targetValueType = targetValueType;
read(in);
}
/**
* Read from a stream.
*/
@SuppressWarnings("unchecked")
private void read(StreamInput in) throws IOException {
fields = in.readMap(StreamInput::readString, MultiValuesSourceFieldConfig::new);
valueType = in.readOptionalWriteable(ValueType::readFromStream);
format = in.readOptionalString();
}
@Override
protected final void doWriteTo(StreamOutput out) throws IOException {
if (serializeTargetValueType()) {
out.writeOptionalWriteable(targetValueType);
}
out.writeMap(fields, StreamOutput::writeString, (o, value) -> value.writeTo(o));
out.writeOptionalWriteable(valueType);
out.writeOptionalString(format);
innerWriteTo(out);
}
/**
* Write subclass' state to the stream
*/
protected abstract void innerWriteTo(StreamOutput out) throws IOException;
@SuppressWarnings("unchecked")
protected AB field(String propertyName, MultiValuesSourceFieldConfig config) {
if (config == null) {
throw new IllegalArgumentException("[config] must not be null: [" + name + "]");
}
this.fields.put(propertyName, config);
return (AB) this;
}
public Map<String, MultiValuesSourceFieldConfig> fields() {
return fields;
}
/**
* Sets the {@link ValueType} for the value produced by this aggregation
*/
@SuppressWarnings("unchecked")
public AB valueType(ValueType valueType) {
if (valueType == null) {
throw new IllegalArgumentException("[valueType] must not be null: [" + name + "]");
}
this.valueType = valueType;
return (AB) this;
}
/**
* Gets the {@link ValueType} for the value produced by this aggregation
*/
public ValueType valueType() {
return valueType;
}
/**
* Sets the format to use for the output of the aggregation.
*/
@SuppressWarnings("unchecked")
public AB format(String format) {
if (format == null) {
throw new IllegalArgumentException("[format] must not be null: [" + name + "]");
}
this.format = format;
return (AB) this;
}
/**
* Gets the format to use for the output of the aggregation.
*/
public String format() {
return format;
}
@Override
protected final MultiValuesSourceAggregatorFactory<VS, ?> doBuild(SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType;
Map<String, ValuesSourceConfig<VS>> configs = new HashMap<>(fields.size());
fields.forEach((key, value) -> {
ValuesSourceConfig<VS> config = ValuesSourceConfig.resolve(context.getQueryShardContext(), finalValueType,
value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format);
configs.put(key, config);
});
DocValueFormat docValueFormat = resolveFormat(format, finalValueType);
return innerBuild(context, configs, docValueFormat, parent, subFactoriesBuilder);
}
private static DocValueFormat resolveFormat(@Nullable String format, @Nullable ValueType valueType) {
if (valueType == null) {
return DocValueFormat.RAW; // we can't figure it out
}
DocValueFormat valueFormat = valueType.defaultFormat;
if (valueFormat instanceof DocValueFormat.Decimal && format != null) {
valueFormat = new DocValueFormat.Decimal(format);
}
return valueFormat;
}
protected abstract MultiValuesSourceAggregatorFactory<VS, ?> innerBuild(SearchContext context,
Map<String, ValuesSourceConfig<VS>> configs, DocValueFormat format, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException;
/**
* Should this builder serialize its targetValueType? Defaults to false. All subclasses that override this to true
* should use the three argument read constructor rather than the four argument version.
*/
protected boolean serializeTargetValueType() {
return false;
}
@Override
public final XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fields != null) {
builder.field(CommonFields.FIELDS.getPreferredName(), fields);
}
if (format != null) {
builder.field(CommonFields.FORMAT.getPreferredName(), format);
}
if (valueType != null) {
builder.field(CommonFields.VALUE_TYPE.getPreferredName(), valueType.getPreferredName());
}
doXContentBody(builder, params);
builder.endObject();
return builder;
}
protected abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException;
@Override
protected final int doHashCode() {
return Objects.hash(fields, format, targetValueType, valueType, innerHashCode());
}
protected abstract int innerHashCode();
@Override
protected final boolean doEquals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
MultiValuesSourceAggregationBuilder that = (MultiValuesSourceAggregationBuilder) other;
return Objects.equals(this.fields, that.fields)
&& Objects.equals(this.format, that.format)
&& Objects.equals(this.valueType, that.valueType);
}
protected abstract boolean innerEquals(Object obj);
}
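For context, the WeightedAvgAggregationBuilder introduced elsewhere in this change is a concrete subclass of this builder, and the tests further down configure it along these lines (a usage sketch; the aggregation name is arbitrary):

    MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
    MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
    WeightedAvgAggregationBuilder weightedAvg = new WeightedAvgAggregationBuilder("weighted_avg_example")
        .value(valueConfig)    // source of the values
        .weight(weightConfig); // source of the weights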

View File

@ -0,0 +1,64 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.support;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
public abstract class MultiValuesSourceAggregatorFactory<VS extends ValuesSource, AF extends MultiValuesSourceAggregatorFactory<VS, AF>>
extends AggregatorFactory<AF> {
protected final Map<String, ValuesSourceConfig<VS>> configs;
protected final DocValueFormat format;
public MultiValuesSourceAggregatorFactory(String name, Map<String, ValuesSourceConfig<VS>> configs,
DocValueFormat format, SearchContext context,
AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactoriesBuilder, metaData);
this.configs = configs;
this.format = format;
}
@Override
public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBucket, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return doCreateInternal(configs, format, parent, collectsFromSingleBucket,
pipelineAggregators, metaData);
}
protected abstract Aggregator createUnmapped(Aggregator parent, List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException;
protected abstract Aggregator doCreateInternal(Map<String, ValuesSourceConfig<VS>> configs,
DocValueFormat format, Aggregator parent, boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException;
}

View File

@ -0,0 +1,186 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.support;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.joda.time.DateTimeZone;
import java.io.IOException;
import java.util.function.BiFunction;
public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragment {
private String fieldName;
private Object missing;
private Script script;
private DateTimeZone timeZone;
private static final String NAME = "field_config";
public static final BiFunction<Boolean, Boolean, ObjectParser<MultiValuesSourceFieldConfig.Builder, Void>> PARSER
= (scriptable, timezoneAware) -> {
ObjectParser<MultiValuesSourceFieldConfig.Builder, Void> parser
= new ObjectParser<>(MultiValuesSourceFieldConfig.NAME, MultiValuesSourceFieldConfig.Builder::new);
parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD);
parser.declareField(MultiValuesSourceFieldConfig.Builder::setMissing, XContentParser::objectText,
ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE);
if (scriptable) {
parser.declareField(MultiValuesSourceFieldConfig.Builder::setScript,
(p, context) -> Script.parse(p),
Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
}
if (timezoneAware) {
parser.declareField(MultiValuesSourceFieldConfig.Builder::setTimeZone, p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return DateTimeZone.forID(p.text());
} else {
return DateTimeZone.forOffsetHours(p.intValue());
}
}, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG);
}
return parser;
};
private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, DateTimeZone timeZone) {
this.fieldName = fieldName;
this.missing = missing;
this.script = script;
this.timeZone = timeZone;
}
public MultiValuesSourceFieldConfig(StreamInput in) throws IOException {
this.fieldName = in.readString();
this.missing = in.readGenericValue();
this.script = in.readOptionalWriteable(Script::new);
this.timeZone = in.readOptionalTimeZone();
}
public Object getMissing() {
return missing;
}
public Script getScript() {
return script;
}
public DateTimeZone getTimeZone() {
return timeZone;
}
public String getFieldName() {
return fieldName;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeGenericValue(missing);
out.writeOptionalWriteable(script);
out.writeOptionalTimeZone(timeZone);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (missing != null) {
builder.field(ParseField.CommonFields.MISSING.getPreferredName(), missing);
}
if (script != null) {
builder.field(Script.SCRIPT_PARSE_FIELD.getPreferredName(), script);
}
if (fieldName != null) {
builder.field(ParseField.CommonFields.FIELD.getPreferredName(), fieldName);
}
if (timeZone != null) {
builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone);
}
return builder;
}
public static class Builder {
private String fieldName;
private Object missing = null;
private Script script = null;
private DateTimeZone timeZone = null;
public String getFieldName() {
return fieldName;
}
public Builder setFieldName(String fieldName) {
this.fieldName = fieldName;
return this;
}
public Object getMissing() {
return missing;
}
public Builder setMissing(Object missing) {
this.missing = missing;
return this;
}
public Script getScript() {
return script;
}
public Builder setScript(Script script) {
this.script = script;
return this;
}
public DateTimeZone getTimeZone() {
return timeZone;
}
public Builder setTimeZone(DateTimeZone timeZone) {
this.timeZone = timeZone;
return this;
}
public MultiValuesSourceFieldConfig build() {
if (Strings.isNullOrEmpty(fieldName) && script == null) {
throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName()
+ "] and [" + Script.SCRIPT_PARSE_FIELD.getPreferredName() + "] cannot both be null. " +
"Please specify one or the other.");
}
if (Strings.isNullOrEmpty(fieldName) == false && script != null) {
throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName()
+ "] and [" + Script.SCRIPT_PARSE_FIELD.getPreferredName() + "] cannot both be configured. " +
"Please specify one or the other.");
}
return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone);
}
}
}
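To summarize the build() contract above, a short sketch consistent with the MultiValuesSourceFieldConfigTests near the end of this change; the script source shown is illustrative only:

    // field-backed configuration
    MultiValuesSourceFieldConfig byField = new MultiValuesSourceFieldConfig.Builder()
        .setFieldName("weight_field")
        .setMissing(1) // optional fallback for documents without the field
        .build();
    // script-backed configuration (mutually exclusive with a field name)
    MultiValuesSourceFieldConfig byScript = new MultiValuesSourceFieldConfig.Builder()
        .setScript(new Script("doc['weight_field'].value")) // illustrative script source
        .build();
    // configuring both a field and a script, or neither, throws IllegalArgumentException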

View File

@ -0,0 +1,59 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.support;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.xcontent.AbstractObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
public final class MultiValuesSourceParseHelper {
public static <VS extends ValuesSource, T> void declareCommon(
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<VS, ?>, T> objectParser, boolean formattable,
ValueType targetValueType) {
objectParser.declareField(MultiValuesSourceAggregationBuilder::valueType, p -> {
ValueType valueType = ValueType.resolveForScript(p.text());
if (targetValueType != null && valueType.isNotA(targetValueType)) {
throw new ParsingException(p.getTokenLocation(),
"Aggregation [" + objectParser.getName() + "] was configured with an incompatible value type ["
+ valueType + "]. It can only work on value of type ["
+ targetValueType + "]");
}
return valueType;
}, ValueType.VALUE_TYPE, ObjectParser.ValueType.STRING);
if (formattable) {
objectParser.declareField(MultiValuesSourceAggregationBuilder::format, XContentParser::text,
ParseField.CommonFields.FORMAT, ObjectParser.ValueType.STRING);
}
}
public static <VS extends ValuesSource, T> void declareField(String fieldName,
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<VS, ?>, T> objectParser,
boolean scriptable, boolean timezoneAware) {
objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()),
(p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null),
new ParseField(fieldName), ObjectParser.ValueType.OBJECT);
}
}
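A hedged sketch of how a concrete builder might wire these helpers into its parser. PARSER is assumed to be that builder's own ObjectParser (e.g. an ObjectParser<WeightedAvgAggregationBuilder, Void>), and "value"/"weight" are the nested object names the weighted avg aggregation uses:

    // declares the shared value_type and format fields on the top-level aggregation object
    MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
    // declares the nested field/missing/script objects, keyed by the given names
    MultiValuesSourceParseHelper.declareField("value", PARSER, true, false);
    MultiValuesSourceParseHelper.declareField("weight", PARSER, true, false);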

View File

@ -19,6 +19,7 @@
package org.elasticsearch.search.aggregations.support;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -95,6 +96,8 @@ public enum ValueType implements Writeable {
private final byte id;
private String preferredName;
public static final ParseField VALUE_TYPE = new ParseField("value_type", "valueType");
ValueType(byte id, String description, String preferredName, ValuesSourceType valuesSourceType,
Class<? extends IndexFieldData> fieldDataType, DocValueFormat defaultFormat) {
this.id = id;
@ -112,7 +115,7 @@ public enum ValueType implements Writeable {
public String getPreferredName() {
return preferredName;
}
public ValuesSourceType getValuesSourceType() {
return valuesSourceType;
}

View File

@ -28,7 +28,6 @@ import org.elasticsearch.script.Script;
import org.joda.time.DateTimeZone;
public final class ValuesSourceParserHelper {
static final ParseField TIME_ZONE = new ParseField("time_zone");
private ValuesSourceParserHelper() {} // utility class, no instantiation
@ -62,10 +61,10 @@ public final class ValuesSourceParserHelper {
objectParser.declareField(ValuesSourceAggregationBuilder::field, XContentParser::text,
new ParseField("field"), ObjectParser.ValueType.STRING);
ParseField.CommonFields.FIELD, ObjectParser.ValueType.STRING);
objectParser.declareField(ValuesSourceAggregationBuilder::missing, XContentParser::objectText,
new ParseField("missing"), ObjectParser.ValueType.VALUE);
ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE);
objectParser.declareField(ValuesSourceAggregationBuilder::valueType, p -> {
ValueType valueType = ValueType.resolveForScript(p.text());
@ -76,11 +75,11 @@ public final class ValuesSourceParserHelper {
+ targetValueType + "]");
}
return valueType;
}, new ParseField("value_type", "valueType"), ObjectParser.ValueType.STRING);
}, ValueType.VALUE_TYPE, ObjectParser.ValueType.STRING);
if (formattable) {
objectParser.declareField(ValuesSourceAggregationBuilder::format, XContentParser::text,
new ParseField("format"), ObjectParser.ValueType.STRING);
ParseField.CommonFields.FORMAT, ObjectParser.ValueType.STRING);
}
if (scriptable) {
@ -96,7 +95,7 @@ public final class ValuesSourceParserHelper {
} else {
return DateTimeZone.forOffsetHours(p.intValue());
}
}, TIME_ZONE, ObjectParser.ValueType.LONG);
}, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG);
}
}

View File

@ -19,9 +19,37 @@
package org.elasticsearch.search.aggregations.support;
public enum ValuesSourceType {
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Locale;
public enum ValuesSourceType implements Writeable {
ANY,
NUMERIC,
BYTES,
GEOPOINT;
public static final ParseField VALUE_SOURCE_TYPE = new ParseField("value_source_type");
public static ValuesSourceType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
public static ValuesSourceType fromStream(StreamInput in) throws IOException {
return in.readEnum(ValuesSourceType.class);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(this);
}
public String value() {
return name().toLowerCase(Locale.ROOT);
}
}
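The new helpers behave as their names suggest; a quick illustration:

    ValuesSourceType numeric = ValuesSourceType.fromString("numeric"); // case-insensitive lookup -> NUMERIC
    String wireName = numeric.value();                                 // lower-cased name -> "numeric"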

View File

@ -0,0 +1,428 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.metrics.weighted_avg;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.joda.time.DateTimeZone;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.function.Consumer;
import static java.util.Collections.singleton;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
public class WeightedAvgAggregatorTests extends AggregatorTestCase {
public void testNoDocs() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
// Intentionally not writing any docs
}, avg -> {
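// no documents were collected, so the weighted average is NaN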
assertEquals(Double.NaN, avg.getValue(), 0);
});
}
public void testNoMatchingField() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 7)));
iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 3)));
}, avg -> {
assertEquals(Double.NaN, avg.getValue(), 0);
});
}
public void testSomeMatchesSortedNumericDocValuesNoWeight() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 7),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 1)));
}, avg -> {
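// all weights are 1, so this reduces to a plain average: (7 + 2 + 3) / 3 == 4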
assertEquals(4, avg.getValue(), 0);
});
}
public void testSomeMatchesSortedNumericDocValuesWeights() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 7),
new SortedNumericDocValuesField("weight_field", 2)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 3)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 3)));
}, avg -> {
// (7*2 + 2*3 + 3*3) / (2+3+3) == 3.625
assertEquals(3.625, avg.getValue(), 0);
});
}
public void testSomeMatchesNumericDocValues() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new DocValuesFieldExistsQuery("value_field"), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 7),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 1)));
}, avg -> {
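// weights are all 1: (7 + 2 + 3) / 3 == 4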
assertEquals(4, avg.getValue(), 0);
});
}
public void testQueryFiltering() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(IntPoint.newRangeQuery("value_field", 0, 3), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 1)));
}, avg -> {
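// the range query matches the docs whose indexed points are 1 and 3; their doc values are 2 and 3 with weight 1, so (2 + 3) / 2 == 2.5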
assertEquals(2.5, avg.getValue(), 0);
});
}
public void testQueryFilteringWeights() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(IntPoint.newRangeQuery("filter_field", 0, 3), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new IntPoint("filter_field", 7), new SortedNumericDocValuesField("value_field", 7),
new SortedNumericDocValuesField("weight_field", 2)));
iw.addDocument(Arrays.asList(new IntPoint("filter_field", 2), new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 3)));
iw.addDocument(Arrays.asList(new IntPoint("filter_field", 3), new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 4)));
}, avg -> {
double value = (2.0*3.0 + 3.0*4.0) / (3.0+4.0);
assertEquals(value, avg.getValue(), 0);
});
}
public void testQueryFiltersAll() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(IntPoint.newRangeQuery("value_field", -1, 0), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7)));
iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2)));
iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 7)));
}, avg -> {
assertEquals(Double.NaN, avg.getValue(), 0);
});
}
public void testQueryFiltersAllWeights() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(IntPoint.newRangeQuery("value_field", -1, 0), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new IntPoint("filter_field", 7), new SortedNumericDocValuesField("value_field", 7),
new SortedNumericDocValuesField("weight_field", 2)));
iw.addDocument(Arrays.asList(new IntPoint("filter_field", 2), new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 3)));
iw.addDocument(Arrays.asList(new IntPoint("filter_field", 3), new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 4)));
}, avg -> {
assertEquals(Double.NaN, avg.getValue(), 0);
});
}
public void testValueSetMissing() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder()
.setFieldName("value_field")
.setMissing(2)
.build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 2)));
iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 3)));
iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 4)));
}, avg -> {
double value = (2.0*2.0 + 2.0*3.0 + 2.0*4.0) / (2.0+3.0+4.0);
assertEquals(value, avg.getValue(), 0);
});
}
public void testWeightSetMissing() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder()
.setFieldName("weight_field")
.setMissing(2)
.build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 2)));
iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3)));
iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4)));
}, avg -> {
double value = (2.0*2.0 + 3.0*2.0 + 4.0*2.0) / (2.0+2.0+2.0);
assertEquals(value, avg.getValue(), 0);
});
}
public void testWeightSetTimezone() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder()
.setFieldName("weight_field")
.setTimeZone(DateTimeZone.UTC)
.build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4),
new SortedNumericDocValuesField("weight_field", 1)));
}, avg -> {
fail("Should not have executed test case");
}));
assertThat(e.getMessage(), equalTo("Field [weight_field] of type [long] does not support custom time zones"));
}
public void testValueSetTimezone() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder()
.setFieldName("value_field")
.setTimeZone(DateTimeZone.UTC)
.build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4),
new SortedNumericDocValuesField("weight_field", 1)));
}, avg -> {
fail("Should not have executed test case");
}));
assertThat(e.getMessage(), equalTo("Field [value_field] of type [long] does not support custom time zones"));
}
public void testMultiValues() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder()
.setFieldName("value_field")
.build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("value_field", 3), new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("value_field", 4), new SortedNumericDocValuesField("weight_field", 1)));
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4),
new SortedNumericDocValuesField("value_field", 5), new SortedNumericDocValuesField("weight_field", 1)));
}, avg -> {
double value = (((2.0+3.0)/2.0) + ((3.0+4.0)/2.0) + ((4.0+5.0)/2.0)) / (1.0+1.0+1.0);
assertEquals(value, avg.getValue(), 0);
});
}
public void testMultiWeight() throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder()
.setFieldName("weight_field")
.build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
AggregationExecutionException e = expectThrows(AggregationExecutionException.class,
() -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> {
iw.addDocument(Arrays.asList(
new SortedNumericDocValuesField("value_field", 2),
new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3)));
iw.addDocument(Arrays.asList(
new SortedNumericDocValuesField("value_field", 3),
new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4)));
iw.addDocument(Arrays.asList(
new SortedNumericDocValuesField("value_field", 4),
new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5)));
}, avg -> {
fail("Should have thrown exception");
}));
assertThat(e.getMessage(), containsString("Encountered more than one weight for a single document. " +
"Use a script to combine multiple weights-per-doc into a single value."));
}
public void testSummationAccuracy() throws IOException {
// Summing up a normal array and expect an accurate value
double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifyAvgOfDoubles(values, 0.9, 0d);
// Summing up an array which contains NaN and infinities and expect a result same as naive summation
int n = randomIntBetween(5, 10);
values = new double[n];
double sum = 0;
for (int i = 0; i < n; i++) {
values[i] = frequently()
? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY)
: randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true);
sum += values[i];
}
verifyAvgOfDoubles(values, sum / n, 1e-10);
// Summing up some big double values and expect infinity result
n = randomIntBetween(5, 10);
double[] largeValues = new double[n];
for (int i = 0; i < n; i++) {
largeValues[i] = Double.MAX_VALUE;
}
verifyAvgOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d);
for (int i = 0; i < n; i++) {
largeValues[i] = -Double.MAX_VALUE;
}
verifyAvgOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d);
}
private void verifyAvgOfDoubles(double[] values, double expected, double delta) throws IOException {
MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build();
MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build();
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
.value(valueConfig)
.weight(weightConfig);
testCase(new MatchAllDocsQuery(), aggregationBuilder,
iw -> {
for (double value : values) {
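// doubles are encoded as sortable longs so that the DOUBLE field type decodes them back to the original values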
iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", NumericUtils.doubleToSortableLong(value)),
new SortedNumericDocValuesField("weight_field", NumericUtils.doubleToSortableLong(1.0))));
}
},
avg -> assertEquals(expected, avg.getValue(), delta),
NumberFieldMapper.NumberType.DOUBLE
);
}
private void testCase(Query query, WeightedAvgAggregationBuilder aggregationBuilder,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalWeightedAvg> verify) throws IOException {
testCase(query, aggregationBuilder, buildIndex, verify, NumberFieldMapper.NumberType.LONG);
}
private void testCase(Query query, WeightedAvgAggregationBuilder aggregationBuilder,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalWeightedAvg> verify,
NumberFieldMapper.NumberType fieldNumberType) throws IOException {
Directory directory = newDirectory();
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
buildIndex.accept(indexWriter);
indexWriter.close();
IndexReader indexReader = DirectoryReader.open(directory);
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
try {
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType.setName("value_field");
fieldType.setHasDocValues(true);
MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(fieldNumberType);
fieldType2.setName("weight_field");
fieldType2.setHasDocValues(true);
WeightedAvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType, fieldType2);
aggregator.preCollection();
indexSearcher.search(query, aggregator);
aggregator.postCollection();
verify.accept((InternalWeightedAvg) aggregator.buildAggregation(0L));
} finally {
indexReader.close();
directory.close();
}
}
}

View File

@ -0,0 +1,38 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.support;
import org.elasticsearch.script.Script;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.equalTo;
public class MultiValuesSourceFieldConfigTests extends ESTestCase {
public void testMissingFieldScript() {
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new MultiValuesSourceFieldConfig.Builder().build());
assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be null. Please specify one or the other."));
}
public void testBothFieldScript() {
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build());
assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. Please specify one or the other."));
}
}

View File

@ -208,7 +208,7 @@ public class RollupRequestTranslator {
private static List<AggregationBuilder> translateDateHistogram(DateHistogramAggregationBuilder source,
List<QueryBuilder> filterConditions,
NamedWriteableRegistry registry) {
return translateVSAggBuilder(source, filterConditions, registry, () -> {
DateHistogramAggregationBuilder rolledDateHisto
= new DateHistogramAggregationBuilder(source.getName());