[7.x] Add support for filters to T-Test aggregation (#54980) (#55066)

Adds support for filters to T-Test aggregation. The filters can be used to
select populations based on some criteria and use values from the same or
different fields.

Closes #53692
Author: Igor Motov, 2020-04-13 12:28:58 -04:00, committed by GitHub
parent a2fafa6af4
commit 51c6f69e02
14 changed files with 510 additions and 83 deletions


@@ -548,7 +548,7 @@ buildRestTests.setups['node_upgrade'] = '''
       number_of_replicas: 1
   mappings:
     properties:
-      name:
+      group:
        type: keyword
       startup_time_before:
        type: long
@@ -560,17 +560,17 @@ buildRestTests.setups['node_upgrade'] = '''
   refresh: true
   body: |
     {"index":{}}
-    {"name": "A", "startup_time_before": 102, "startup_time_after": 89}
+    {"group": "A", "startup_time_before": 102, "startup_time_after": 89}
     {"index":{}}
-    {"name": "B", "startup_time_before": 99, "startup_time_after": 93}
+    {"group": "A", "startup_time_before": 99, "startup_time_after": 93}
     {"index":{}}
-    {"name": "C", "startup_time_before": 111, "startup_time_after": 72}
+    {"group": "A", "startup_time_before": 111, "startup_time_after": 72}
     {"index":{}}
-    {"name": "D", "startup_time_before": 97, "startup_time_after": 98}
+    {"group": "B", "startup_time_before": 97, "startup_time_after": 98}
     {"index":{}}
-    {"name": "E", "startup_time_before": 101, "startup_time_after": 102}
+    {"group": "B", "startup_time_before": 101, "startup_time_after": 102}
     {"index":{}}
-    {"name": "F", "startup_time_before": 99, "startup_time_after": 98}'''
+    {"group": "B", "startup_time_before": 99, "startup_time_after": 98}'''
 // Used by iprange agg
 buildRestTests.setups['iprange'] = '''


@@ -1,7 +1,7 @@
 [role="xpack"]
 [testenv="basic"]
 [[search-aggregations-metrics-ttest-aggregation]]
-=== TTest Aggregation
+=== T-Test Aggregation

 A `t_test` metrics aggregation that performs a statistical hypothesis test in which the test statistic follows a Student's t-distribution
 under the null hypothesis on numeric values extracted from the aggregated documents or generated by provided scripts. In practice, this
@@ -43,8 +43,8 @@ GET node_upgrade/_search
 }
 --------------------------------------------------
 // TEST[setup:node_upgrade]
-<1> The field `startup_time_before` must be a numeric field
-<2> The field `startup_time_after` must be a numeric field
+<1> The field `startup_time_before` must be a numeric field.
+<2> The field `startup_time_after` must be a numeric field.
 <3> Since we have data from the same nodes, we are using paired t-test.

 The response will return the p-value or probability value for the test. It is the probability of obtaining results at least as extreme as
@@ -74,6 +74,69 @@ The `t_test` aggregation supports unpaired and paired two-sample t-tests. The ty
 `"type": "homoscedastic"`:: performs two-sample equal variance test
 `"type": "heteroscedastic"`:: performs two-sample unequal variance test (this is default)

+==== Filters
+
+It is also possible to run an unpaired t-test on different sets of records using filters. For example, to test the difference
+in startup times before the upgrade between two different groups of nodes, we can use the same field `startup_time_before` and
+separate the two populations with term filters on the group name field:
+
+[source,console]
+--------------------------------------------------
+GET node_upgrade/_search
+{
+  "size" : 0,
+  "aggs" : {
+    "startup_time_ttest" : {
+      "t_test" : {
+        "a" : {
+          "field" : "startup_time_before",    <1>
+          "filter" : {
+            "term" : {
+              "group" : "A"                   <2>
+            }
+          }
+        },
+        "b" : {
+          "field" : "startup_time_before",    <3>
+          "filter" : {
+            "term" : {
+              "group" : "B"                   <4>
+            }
+          }
+        },
+        "type" : "heteroscedastic"            <5>
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[setup:node_upgrade]
+
+<1> The field `startup_time_before` must be a numeric field.
+<2> Any query that separates the two groups can be used here.
+<3> We are using the same field,
+<4> but we are using different filters.
+<5> Since we have data from different nodes, we cannot use a paired t-test.
+
+[source,console-result]
+--------------------------------------------------
+{
+  ...
+  "aggregations": {
+    "startup_time_ttest": {
+      "value": 0.2981858007281437   <1>
+    }
+  }
+}
+--------------------------------------------------
+// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/]
+
+<1> The p-value.
+
+In this example, we are using the same field for both populations. However, this is not a requirement: different fields, and even
+combinations of fields and scripts, can be used. The populations don't have to be in the same index either. If the data sets are
+located in different indices, a term filter on the <<mapping-index-field,`_index`>> field can be used to select populations.
+
 ==== Script

 The `t_test` metric supports scripting. For example, if we need to adjust our load times for the before values, we could use
@@ -108,7 +171,7 @@ GET node_upgrade/_search
 // TEST[setup:node_upgrade]
 <1> The `field` parameter is replaced with a `script` parameter, which uses the
-script to generate values which percentiles are calculated on
-<2> Scripting supports parameterized input just like any other script
-<3> We can mix scripts and fields
+script to generate values which percentiles are calculated on.
+<2> Scripting supports parameterized input just like any other script.
+<3> We can mix scripts and fields.
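As a sanity check on the documented example above, the heteroscedastic (Welch) statistic for the two filtered populations can be reproduced with a short stdlib-only Python sketch; the p-value `0.2981858007281437` shown in the docs corresponds to a t statistic of roughly 1.32 with roughly 2.4 degrees of freedom. This is an illustration of the statistic the aggregation computes, not Elasticsearch code:

```python
import math

def welch_t(a, b):
    """Welch's unequal-variance t statistic and Welch-Satterthwaite degrees of freedom."""
    na, nb = len(a), len(b)
    ma, mb = sum(a) / na, sum(b) / nb
    # unbiased (n - 1) sample variances
    va = sum((x - ma) ** 2 for x in a) / (na - 1)
    vb = sum((x - mb) ** 2 for x in b) / (nb - 1)
    se2 = va / na + vb / nb
    t = (ma - mb) / math.sqrt(se2)
    # Welch-Satterthwaite approximation for the degrees of freedom
    df = se2 ** 2 / ((va / na) ** 2 / (na - 1) + (vb / nb) ** 2 / (nb - 1))
    return t, df

# startup_time_before values, split exactly as the term filters in the docs example split them
group_a = [102, 99, 111]   # documents matching {"group": "A"}
group_b = [97, 101, 99]    # documents matching {"group": "B"}
t, df = welch_t(group_a, group_b)
print(round(t, 4), round(df, 3))   # t ≈ 1.3207, df ≈ 2.406
```

Feeding these values through a two-tailed Student's t CDF yields the p-value in the documented response.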


@@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
@@ -51,8 +52,8 @@ public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationB
         ObjectParser.fromBuilder(NAME, WeightedAvgAggregationBuilder::new);
     static {
         MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
-        MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false);
-        MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false);
+        MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false, false);
+        MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false, false);
     }

     public WeightedAvgAggregationBuilder(String name) {
@@ -99,10 +100,11 @@ public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationB
     @Override
     protected MultiValuesSourceAggregatorFactory<Numeric> innerBuild(QueryShardContext queryShardContext,
                                                                      Map<String, ValuesSourceConfig<Numeric>> configs,
+                                                                     Map<String, QueryBuilder> filters,
                                                                      DocValueFormat format,
                                                                      AggregatorFactory parent,
                                                                      Builder subFactoriesBuilder) throws IOException {
         return new WeightedAvgAggregatorFactory(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata);
     }


@@ -22,6 +22,7 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
@@ -168,13 +169,15 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
         ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType;

         Map<String, ValuesSourceConfig<VS>> configs = new HashMap<>(fields.size());
+        Map<String, QueryBuilder> filters = new HashMap<>(fields.size());
         fields.forEach((key, value) -> {
             ValuesSourceConfig<VS> config = ValuesSourceConfig.resolve(queryShardContext, finalValueType,
                 value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format);
             configs.put(key, config);
+            filters.put(key, value.getFilter());
         });
         DocValueFormat docValueFormat = resolveFormat(format, finalValueType);
-        return innerBuild(queryShardContext, configs, docValueFormat, parent, subFactoriesBuilder);
+        return innerBuild(queryShardContext, configs, filters, docValueFormat, parent, subFactoriesBuilder);
     }
@@ -191,6 +194,7 @@ public abstract class MultiValuesSourceAggregationBuilder<VS extends ValuesSourc
     protected abstract MultiValuesSourceAggregatorFactory<VS> innerBuild(QueryShardContext queryShardContext,
                                                                          Map<String, ValuesSourceConfig<VS>> configs,
+                                                                         Map<String, QueryBuilder> filters,
                                                                          DocValueFormat format, AggregatorFactory parent,
                                                                          Builder subFactoriesBuilder) throws IOException;


@@ -30,26 +30,30 @@ import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.script.Script;

 import java.io.IOException;
 import java.time.ZoneId;
 import java.time.ZoneOffset;
 import java.util.Objects;
-import java.util.function.BiFunction;

 public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject {
-    private String fieldName;
-    private Object missing;
-    private Script script;
-    private ZoneId timeZone;
+    private final String fieldName;
+    private final Object missing;
+    private final Script script;
+    private final ZoneId timeZone;
+    private final QueryBuilder filter;

     private static final String NAME = "field_config";
+    public static final ParseField FILTER = new ParseField("filter");

-    public static final BiFunction<Boolean, Boolean, ObjectParser<MultiValuesSourceFieldConfig.Builder, Void>> PARSER
-        = (scriptable, timezoneAware) -> {
-        ObjectParser<MultiValuesSourceFieldConfig.Builder, Void> parser
+    public static <C> ObjectParser<MultiValuesSourceFieldConfig.Builder, C> parserBuilder(boolean scriptable, boolean timezoneAware,
+                                                                                          boolean filtered) {
+        ObjectParser<MultiValuesSourceFieldConfig.Builder, C> parser
             = new ObjectParser<>(MultiValuesSourceFieldConfig.NAME, MultiValuesSourceFieldConfig.Builder::new);
         parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD);
@@ -71,14 +75,21 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
                 }
             }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG);
         }
+        if (filtered) {
+            parser.declareField(MultiValuesSourceFieldConfig.Builder::setFilter,
+                (p, context) -> AbstractQueryBuilder.parseInnerQueryBuilder(p),
+                FILTER, ObjectParser.ValueType.OBJECT);
+        }
         return parser;
-    };
+    }

-    private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone) {
+    protected MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone, QueryBuilder filter) {
         this.fieldName = fieldName;
         this.missing = missing;
         this.script = script;
         this.timeZone = timeZone;
+        this.filter = filter;
     }

     public MultiValuesSourceFieldConfig(StreamInput in) throws IOException {
@@ -94,6 +105,11 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
         } else {
             this.timeZone = in.readOptionalZoneId();
         }
+        if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
+            this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
+        } else {
+            this.filter = null;
+        }
     }

     public Object getMissing() {
@@ -112,6 +128,10 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
         return fieldName;
     }

+    public QueryBuilder getFilter() {
+        return filter;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
@@ -126,6 +146,9 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
         } else {
             out.writeOptionalZoneId(timeZone);
         }
+        if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
+            out.writeOptionalNamedWriteable(filter);
+        }
     }

     @Override
@@ -143,6 +166,10 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
         if (timeZone != null) {
             builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone.getId());
         }
+        if (filter != null) {
+            builder.field(FILTER.getPreferredName());
+            filter.toXContent(builder, params);
+        }
         builder.endObject();
         return builder;
     }
@@ -155,12 +182,13 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
         return Objects.equals(fieldName, that.fieldName)
             && Objects.equals(missing, that.missing)
             && Objects.equals(script, that.script)
-            && Objects.equals(timeZone, that.timeZone);
+            && Objects.equals(timeZone, that.timeZone)
+            && Objects.equals(filter, that.filter);
     }

     @Override
     public int hashCode() {
-        return Objects.hash(fieldName, missing, script, timeZone);
+        return Objects.hash(fieldName, missing, script, timeZone, filter);
     }

     @Override
@@ -173,6 +201,7 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
         private Object missing = null;
         private Script script = null;
         private ZoneId timeZone = null;
+        private QueryBuilder filter = null;

         public String getFieldName() {
             return fieldName;
@@ -210,6 +239,11 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
             return this;
         }

+        public Builder setFilter(QueryBuilder filter) {
+            this.filter = filter;
+            return this;
+        }
+
         public MultiValuesSourceFieldConfig build() {
             if (Strings.isNullOrEmpty(fieldName) && script == null) {
                 throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName()
@@ -223,7 +257,7 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
                 "Please specify one or the other.");
             }

-            return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone);
+            return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone, filter);
         }
     }
 }
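The version-gated `readOptionalNamedWriteable`/`writeOptionalNamedWriteable` pair in the diff above is the standard wire-compatibility idiom: the new `filter` field is only written to (or read from) the stream when the peer is on 7.8.0 or later, so older nodes simply see the pre-filter layout. A minimal stdlib-only Python sketch of the same idea; the stream classes and byte layout here are hypothetical stand-ins, not the Elasticsearch wire format:

```python
from io import BytesIO

V_7_6_0, V_7_8_0 = (7, 6, 0), (7, 8, 0)

class Out:
    """Toy output stream tagged with the receiving peer's version."""
    def __init__(self, version):
        self.version, self.buf = version, BytesIO()
    def write_optional_string(self, s):
        if s is None:
            self.buf.write(b"\x00")          # presence marker: absent
        else:
            data = s.encode()
            self.buf.write(b"\x01" + len(data).to_bytes(2, "big") + data)

class In:
    """Toy input stream tagged with the sending peer's version."""
    def __init__(self, version, payload):
        self.version, self.buf = version, BytesIO(payload)
    def read_optional_string(self):
        if self.buf.read(1) == b"\x00":
            return None
        n = int.from_bytes(self.buf.read(2), "big")
        return self.buf.read(n).decode()

def write_config(out, field_name, filter_json):
    out.write_optional_string(field_name)
    # only peers that know about filters get the trailing field
    if out.version >= V_7_8_0:
        out.write_optional_string(filter_json)

def read_config(inp):
    field_name = inp.read_optional_string()
    # an old sender never wrote the filter, so don't try to read it
    filter_json = inp.read_optional_string() if inp.version >= V_7_8_0 else None
    return field_name, filter_json
```

Round-tripping against a 7.8.0 peer preserves the filter; against a 7.6.0 peer the filter is silently dropped, which is exactly the degradation the real `StreamInput`/`StreamOutput` gating produces.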


@@ -50,10 +50,10 @@ public final class MultiValuesSourceParseHelper {
     public static <VS extends ValuesSource, T> void declareField(String fieldName,
             AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<VS, ?>, T> objectParser,
-            boolean scriptable, boolean timezoneAware) {
+            boolean scriptable, boolean timezoneAware, boolean filterable) {

         objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()),
-            (p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null),
+            (p, c) -> MultiValuesSourceFieldConfig.parserBuilder(scriptable, timezoneAware, filterable).parse(p, null),
             new ParseField(fieldName), ObjectParser.ValueType.OBJECT);
     }
 }


@@ -19,13 +19,20 @@
 package org.elasticsearch.search.aggregations.support;

+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.script.Script;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractSerializingTestCase;

 import java.io.IOException;
 import java.time.ZoneId;
+import java.util.Collections;

 import static org.hamcrest.Matchers.equalTo;
@@ -33,7 +40,7 @@ public class MultiValuesSourceFieldConfigTests extends AbstractSerializingTestCa
     @Override
     protected MultiValuesSourceFieldConfig doParseInstance(XContentParser parser) throws IOException {
-        return MultiValuesSourceFieldConfig.PARSER.apply(true, true).apply(parser, null).build();
+        return MultiValuesSourceFieldConfig.parserBuilder(true, true, true).apply(parser, null).build();
     }

     @Override
@@ -41,8 +48,9 @@ public class MultiValuesSourceFieldConfigTests extends AbstractSerializingTestCa
         String field = randomAlphaOfLength(10);
         Object missing = randomBoolean() ? randomAlphaOfLength(10) : null;
         ZoneId timeZone = randomBoolean() ? randomZone() : null;
+        QueryBuilder filter = randomBoolean() ? QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null;
         return new MultiValuesSourceFieldConfig.Builder()
-            .setFieldName(field).setMissing(missing).setScript(null).setTimeZone(timeZone).build();
+            .setFieldName(field).setMissing(missing).setScript(null).setTimeZone(timeZone).setFilter(filter).build();
     }

     @Override
@@ -60,4 +68,14 @@ public class MultiValuesSourceFieldConfigTests extends AbstractSerializingTestCa
             () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build());
         assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. Please specify one or the other."));
     }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+    }
 }


@@ -55,7 +55,8 @@ public class TopMetricsAggregationBuilder extends AbstractAggregationBuilder<Top
         PARSER.declareField(constructorArg(), (p, n) -> SortBuilder.fromXContent(p), SORT_FIELD,
             ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING);
         PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD);
-        ContextParser<Void, MultiValuesSourceFieldConfig.Builder> metricParser = MultiValuesSourceFieldConfig.PARSER.apply(true, false);
+        ContextParser<Void, MultiValuesSourceFieldConfig.Builder> metricParser =
+            MultiValuesSourceFieldConfig.parserBuilder(true, false, false);
         PARSER.declareObjectArray(constructorArg(), (p, n) -> metricParser.parse(p, null).build(), METRIC_FIELD);
     }


@@ -12,11 +12,13 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.AggregatorFactory;
+import org.elasticsearch.search.aggregations.support.FieldContext;
 import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
 import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
@@ -41,11 +43,10 @@ public class TTestAggregationBuilder extends MultiValuesSourceAggregationBuilder
     static {
         MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
-        MultiValuesSourceParseHelper.declareField(A_FIELD.getPreferredName(), PARSER, true, false);
-        MultiValuesSourceParseHelper.declareField(B_FIELD.getPreferredName(), PARSER, true, false);
+        MultiValuesSourceParseHelper.declareField(A_FIELD.getPreferredName(), PARSER, true, false, true);
+        MultiValuesSourceParseHelper.declareField(B_FIELD.getPreferredName(), PARSER, true, false, true);
         PARSER.declareString(TTestAggregationBuilder::testType, TYPE_FIELD);
         PARSER.declareInt(TTestAggregationBuilder::tails, TAILS_FIELD);
     }

     private TTestType testType = TTestType.HETEROSCEDASTIC;
@@ -117,10 +118,26 @@ public class TTestAggregationBuilder extends MultiValuesSourceAggregationBuilder
     protected MultiValuesSourceAggregatorFactory<ValuesSource.Numeric> innerBuild(
         QueryShardContext queryShardContext,
         Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs,
+        Map<String, QueryBuilder> filters,
         DocValueFormat format,
         AggregatorFactory parent,
         AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
-        return new TTestAggregatorFactory(name, configs, testType, tails, format, queryShardContext, parent, subFactoriesBuilder, metadata);
+        QueryBuilder filterA = filters.get(A_FIELD.getPreferredName());
+        QueryBuilder filterB = filters.get(B_FIELD.getPreferredName());
+        if (filterA == null && filterB == null) {
+            FieldContext fieldContextA = configs.get(A_FIELD.getPreferredName()).fieldContext();
+            FieldContext fieldContextB = configs.get(B_FIELD.getPreferredName()).fieldContext();
+            if (fieldContextA != null && fieldContextB != null) {
+                if (fieldContextA.field().equals(fieldContextB.field())) {
+                    throw new IllegalArgumentException("The same field [" + fieldContextA.field() +
+                        "] is used for both populations but no filters are specified.");
+                }
+            }
+        }
+        return new TTestAggregatorFactory(name, configs, testType, tails,
+            filterA, filterB, format, queryShardContext, parent,
+            subFactoriesBuilder, metadata);
     }

     @Override
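The guard added to `innerBuild` above rejects the degenerate unpaired case where both populations read the same field and no filters are given, since every document would then land in both samples. A hypothetical Python sketch of that validation logic (the dict-based configs are stand-ins for the Java `ValuesSourceConfig`/`QueryBuilder` objects):

```python
def validate_ttest(config_a, config_b):
    """Reject an unpaired t-test whose two populations would be identical.

    config_a / config_b are dicts like {"field": str, "filter": dict-or-None},
    a simplified stand-in for the aggregation's per-population config.
    """
    filter_a, filter_b = config_a.get("filter"), config_b.get("filter")
    field_a, field_b = config_a.get("field"), config_b.get("field")
    # only a problem when neither side narrows its population with a filter
    if filter_a is None and filter_b is None:
        if field_a is not None and field_a == field_b:
            raise ValueError(
                f"The same field [{field_a}] is used for both populations "
                "but no filters are specified.")

# valid: same field, but each population is selected by a different filter
validate_ttest(
    {"field": "startup_time_before", "filter": {"term": {"group": "A"}}},
    {"field": "startup_time_before", "filter": {"term": {"group": "B"}}},
)
```

With at least one filter present, or with two distinct fields, the configuration passes; the real check additionally tolerates script-only configs, where no field context exists.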


@ -6,8 +6,15 @@
package org.elasticsearch.xpack.analytics.ttest; package org.elasticsearch.xpack.analytics.ttest;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationInitializationException;
import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.AggregatorFactory;
@ -24,14 +31,20 @@ class TTestAggregatorFactory extends MultiValuesSourceAggregatorFactory<ValuesSo
private final TTestType testType; private final TTestType testType;
private final int tails; private final int tails;
private final Query filterA;
private final Query filterB;
private Tuple<Weight, Weight> weights;
TTestAggregatorFactory(String name, Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, TTestType testType, int tails, TTestAggregatorFactory(String name, Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, TTestType testType, int tails,
QueryBuilder filterA, QueryBuilder filterB,
DocValueFormat format, QueryShardContext queryShardContext, AggregatorFactory parent, DocValueFormat format, QueryShardContext queryShardContext, AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder, AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metadata) throws IOException { Map<String, Object> metadata) throws IOException {
super(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata); super(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata);
this.testType = testType; this.testType = testType;
this.tails = tails; this.tails = tails;
this.filterA = filterA == null ? null : filterA.toQuery(queryShardContext);
this.filterB = filterB == null ? null : filterB.toQuery(queryShardContext);
} }
@Override @Override
@ -42,9 +55,9 @@ class TTestAggregatorFactory extends MultiValuesSourceAggregatorFactory<ValuesSo
case PAIRED: case PAIRED:
return new PairedTTestAggregator(name, null, tails, format, searchContext, parent, metadata); return new PairedTTestAggregator(name, null, tails, format, searchContext, parent, metadata);
case HOMOSCEDASTIC: case HOMOSCEDASTIC:
return new UnpairedTTestAggregator(name, null, tails, true, format, searchContext, parent, metadata); return new UnpairedTTestAggregator(name, null, tails, true, this::getWeights, format, searchContext, parent, metadata);
case HETEROSCEDASTIC: case HETEROSCEDASTIC:
return new UnpairedTTestAggregator(name, null, tails, false, format, searchContext, parent, metadata); return new UnpairedTTestAggregator(name, null, tails, false, this::getWeights, format, searchContext, parent, metadata);
default: default:
throw new IllegalArgumentException("Unsupported t-test type " + testType); throw new IllegalArgumentException("Unsupported t-test type " + testType);
} }
@@ -64,13 +77,46 @@ class TTestAggregatorFactory extends MultiValuesSourceAggregatorFactory<ValuesSo
} }
switch (testType) { switch (testType) {
case PAIRED: case PAIRED:
if (filterA != null || filterB != null) {
throw new IllegalArgumentException("Paired t-test doesn't support filters");
}
return new PairedTTestAggregator(name, numericMultiVS, tails, format, searchContext, parent, metadata); return new PairedTTestAggregator(name, numericMultiVS, tails, format, searchContext, parent, metadata);
case HOMOSCEDASTIC: case HOMOSCEDASTIC:
return new UnpairedTTestAggregator(name, numericMultiVS, tails, true, format, searchContext, parent, metadata); return new UnpairedTTestAggregator(name, numericMultiVS, tails, true, this::getWeights, format, searchContext, parent,
metadata);
case HETEROSCEDASTIC: case HETEROSCEDASTIC:
return new UnpairedTTestAggregator(name, numericMultiVS, tails, false, format, searchContext, parent, metadata); return new UnpairedTTestAggregator(name, numericMultiVS, tails, false, this::getWeights, format, searchContext,
parent, metadata);
default: default:
throw new IllegalArgumentException("Unsupported t-test type " + testType); throw new IllegalArgumentException("Unsupported t-test type " + testType);
} }
} }
/**
* Returns the {@link Weight}s for the filters, creating them if
* necessary. This is done lazily so that the {@link Weight}s are only created
* if the aggregation collects documents, reducing the overhead of the
* aggregation in the case where no documents are collected.
*
* Note that as aggregations are initialised and executed in a serial manner,
* no concurrency considerations are necessary here.
*/
public Tuple<Weight, Weight> getWeights() {
if (weights == null) {
weights = new Tuple<>(getWeight(filterA), getWeight(filterB));
}
return weights;
}
public Weight getWeight(Query filter) {
if (filter != null) {
IndexSearcher contextSearcher = queryShardContext.searcher();
try {
return contextSearcher.createWeight(contextSearcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1f);
} catch (IOException e) {
throw new AggregationInitializationException("Failed to initialize filter", e);
}
}
return null;
}
} }
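The `getWeights()` javadoc above describes a lazy, build-at-most-once pattern: the Lucene `Weight`s are only created the first time a leaf collector asks for them, and no locking is needed because aggregations run serially. Stripped of the Lucene types, this is plain memoization through a `Supplier` (a sketch; the `Lazy` class and build counter are illustrative, not part of this commit):

```java
import java.util.function.Supplier;

// Sketch of the lazy, single-threaded memoization used by getWeights():
// the expensive object is built on first access only, so aggregations that
// never collect documents never pay the cost. No synchronization is needed
// because aggregations are initialised and executed serially.
public class LazyWeightsSketch {
    static final class Lazy<T> {
        private final Supplier<T> factory;
        private T value;

        Lazy(Supplier<T> factory) {
            this.factory = factory;
        }

        T get() {
            if (value == null) {
                value = factory.get(); // created at most once
            }
            return value;
        }
    }

    static int builds = 0;

    public static void main(String[] args) {
        Lazy<String> weights = new Lazy<>(() -> {
            builds++;
            return "weights";
        });
        int before = builds;                       // nothing built yet
        weights.get();
        weights.get();
        System.out.println(before + " " + builds); // prints "0 1"
    }
}
```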
@@ -7,7 +7,11 @@
package org.elasticsearch.xpack.analytics.ttest; package org.elasticsearch.xpack.analytics.ttest;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.DocValueFormat;
@@ -20,6 +24,7 @@ import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.A_FIELD; import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.A_FIELD;
import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.B_FIELD; import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.B_FIELD;
@@ -28,14 +33,16 @@ public class UnpairedTTestAggregator extends TTestAggregator<UnpairedTTestState>
private final TTestStatsBuilder a; private final TTestStatsBuilder a;
private final TTestStatsBuilder b; private final TTestStatsBuilder b;
private final boolean homoscedastic; private final boolean homoscedastic;
private final Supplier<Tuple<Weight, Weight>> weightsSupplier;
UnpairedTTestAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, int tails, boolean homoscedastic, UnpairedTTestAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, int tails, boolean homoscedastic,
DocValueFormat format, SearchContext context, Aggregator parent, Supplier<Tuple<Weight, Weight>> weightsSupplier, DocValueFormat format, SearchContext context,
Map<String, Object> metadata) throws IOException { Aggregator parent, Map<String, Object> metadata) throws IOException {
super(name, valuesSources, tails, format, context, parent, metadata); super(name, valuesSources, tails, format, context, parent, metadata);
BigArrays bigArrays = context.bigArrays(); BigArrays bigArrays = context.bigArrays();
a = new TTestStatsBuilder(bigArrays); a = new TTestStatsBuilder(bigArrays);
b = new TTestStatsBuilder(bigArrays); b = new TTestStatsBuilder(bigArrays);
this.weightsSupplier = weightsSupplier;
this.homoscedastic = homoscedastic; this.homoscedastic = homoscedastic;
} }
@@ -67,6 +74,9 @@ public class UnpairedTTestAggregator extends TTestAggregator<UnpairedTTestState>
final CompensatedSum compSumOfSqrA = new CompensatedSum(0, 0); final CompensatedSum compSumOfSqrA = new CompensatedSum(0, 0);
final CompensatedSum compSumB = new CompensatedSum(0, 0); final CompensatedSum compSumB = new CompensatedSum(0, 0);
final CompensatedSum compSumOfSqrB = new CompensatedSum(0, 0); final CompensatedSum compSumOfSqrB = new CompensatedSum(0, 0);
final Tuple<Weight, Weight> weights = weightsSupplier.get();
final Bits bitsA = getBits(ctx, weights.v1());
final Bits bitsB = getBits(ctx, weights.v2());
return new LeafBucketCollectorBase(sub, docAValues) { return new LeafBucketCollectorBase(sub, docAValues) {
@@ -82,14 +92,25 @@ public class UnpairedTTestAggregator extends TTestAggregator<UnpairedTTestState>
@Override @Override
public void collect(int doc, long bucket) throws IOException { public void collect(int doc, long bucket) throws IOException {
a.grow(bigArrays, bucket + 1); if (bitsA == null || bitsA.get(doc)) {
b.grow(bigArrays, bucket + 1); a.grow(bigArrays, bucket + 1);
processValues(doc, bucket, docAValues, compSumA, compSumOfSqrA, a); processValues(doc, bucket, docAValues, compSumA, compSumOfSqrA, a);
processValues(doc, bucket, docBValues, compSumB, compSumOfSqrB, b); }
if (bitsB == null || bitsB.get(doc)) {
b.grow(bigArrays, bucket + 1);
processValues(doc, bucket, docBValues, compSumB, compSumOfSqrB, b);
}
} }
}; };
} }
private Bits getBits(LeafReaderContext ctx, Weight weight) throws IOException {
if (weight == null) {
return null;
}
return Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), weight.scorerSupplier(ctx));
}
@Override @Override
public void doClose() { public void doClose() {
Releasables.close(a, b); Releasables.close(a, b);
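The `collect()` change above routes each document into population A and/or B only when the corresponding filter's `Bits` accept it, with a `null` filter meaning "no filter configured, accept everything". A minimal stand-in for that routing, using `IntPredicate` in place of Lucene's `Bits` (names here are illustrative, not from this commit):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntPredicate;

// Stand-in for the per-document routing in collect(): a document contributes
// to a population only if that population's filter accepts it, and a null
// filter (no filter configured) accepts every document.
public class FilteredCollectSketch {
    static List<Integer> collect(int[] docIds, IntPredicate bits) {
        List<Integer> population = new ArrayList<>();
        for (int doc : docIds) {
            if (bits == null || bits.test(doc)) {
                population.add(doc);
            }
        }
        return population;
    }

    public static void main(String[] args) {
        int[] docs = {0, 1, 2, 3, 4, 5};
        // "filter A" accepts even doc ids, "filter B" accepts odd ones;
        // in general a document may match both filters, one, or neither.
        System.out.println(collect(docs, d -> d % 2 == 0)); // [0, 2, 4]
        System.out.println(collect(docs, d -> d % 2 == 1)); // [1, 3, 5]
        System.out.println(collect(docs, null).size());     // 6 (no filter)
    }
}
```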
@@ -7,10 +7,14 @@
package org.elasticsearch.xpack.analytics.ttest; package org.elasticsearch.xpack.analytics.ttest;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.BaseAggregationBuilder; import org.elasticsearch.search.aggregations.BaseAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
@@ -18,8 +22,10 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static java.util.Collections.singletonList;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
public class TTestAggregationBuilderTests extends AbstractSerializingTestCase<TTestAggregationBuilder> { public class TTestAggregationBuilderTests extends AbstractSerializingTestCase<TTestAggregationBuilder> {
@@ -30,14 +36,6 @@ public class TTestAggregationBuilderTests extends AbstractSerializingTestCase<TT
aggregationName = randomAlphaOfLength(10); aggregationName = randomAlphaOfLength(10);
} }
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(singletonList(new NamedXContentRegistry.Entry(
BaseAggregationBuilder.class,
new ParseField(TTestAggregationBuilder.NAME),
(p, n) -> TTestAggregationBuilder.PARSER.apply(p, (String) n))));
}
@Override @Override
protected TTestAggregationBuilder doParseInstance(XContentParser parser) throws IOException { protected TTestAggregationBuilder doParseInstance(XContentParser parser) throws IOException {
assertSame(XContentParser.Token.START_OBJECT, parser.nextToken()); assertSame(XContentParser.Token.START_OBJECT, parser.nextToken());
@@ -52,26 +50,33 @@ public class TTestAggregationBuilderTests extends AbstractSerializingTestCase<TT
@Override @Override
protected TTestAggregationBuilder createTestInstance() { protected TTestAggregationBuilder createTestInstance() {
MultiValuesSourceFieldConfig aConfig; MultiValuesSourceFieldConfig.Builder aConfig;
TTestType tTestType = randomFrom(TTestType.values());
if (randomBoolean()) { if (randomBoolean()) {
aConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("a_field").build(); aConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("a_field");
} else { } else {
aConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))).build(); aConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10)));
} }
MultiValuesSourceFieldConfig bConfig; MultiValuesSourceFieldConfig.Builder bConfig;
if (randomBoolean()) { if (randomBoolean()) {
bConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("b_field").build(); bConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("b_field");
} else { } else {
bConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))).build(); bConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10)));
}
if (tTestType != TTestType.PAIRED && randomBoolean()) {
aConfig.setFilter(QueryBuilders.queryStringQuery(randomAlphaOfLength(10)));
}
if (tTestType != TTestType.PAIRED && randomBoolean()) {
bConfig.setFilter(QueryBuilders.queryStringQuery(randomAlphaOfLength(10)));
} }
TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder(aggregationName) TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder(aggregationName)
.a(aConfig) .a(aConfig.build())
.b(bConfig); .b(bConfig.build());
if (randomBoolean()) { if (randomBoolean()) {
aggregationBuilder.tails(randomIntBetween(1, 2)); aggregationBuilder.tails(randomIntBetween(1, 2));
} }
if (randomBoolean()) { if (tTestType != TTestType.HETEROSCEDASTIC || randomBoolean()) {
aggregationBuilder.testType(randomFrom(TTestType.values())); aggregationBuilder.testType(randomFrom(tTestType));
} }
return aggregationBuilder; return aggregationBuilder;
} }
@@ -80,5 +85,21 @@ public class TTestAggregationBuilderTests extends AbstractSerializingTestCase<TT
protected Writeable.Reader<TTestAggregationBuilder> instanceReader() { protected Writeable.Reader<TTestAggregationBuilder> instanceReader() {
return TTestAggregationBuilder::new; return TTestAggregationBuilder::new;
} }
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.add(new NamedXContentRegistry.Entry(
BaseAggregationBuilder.class,
new ParseField(TTestAggregationBuilder.NAME),
(p, n) -> TTestAggregationBuilder.PARSER.apply(p, (String) n)));
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
} }
@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.analytics.ttest; package org.elasticsearch.xpack.analytics.ttest;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
@@ -21,6 +22,7 @@ import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.MockScriptEngine; import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script; import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptEngine;
@@ -56,11 +58,15 @@ public class TTestAggregatorTests extends AggregatorTestCase {
*/ */
public static final String ADD_HALF_SCRIPT = "add_one"; public static final String ADD_HALF_SCRIPT = "add_one";
public static final String TERM_FILTERING = "term_filtering";
@Override @Override
protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) { protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) {
return new TTestAggregationBuilder("foo") return new TTestAggregationBuilder("foo")
.a(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName).build()) .a(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName)
.b(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName).build()); .setFilter(QueryBuilders.rangeQuery(fieldName).lt(10)).build())
.b(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName)
.setFilter(QueryBuilders.rangeQuery(fieldName).gte(10)).build());
} }
@Override @Override
@@ -71,11 +77,18 @@ public class TTestAggregatorTests extends AggregatorTestCase {
LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc"); LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc");
String fieldname = (String) vars.get("fieldname"); String fieldname = (String) vars.get("fieldname");
ScriptDocValues<?> scriptDocValues = leafDocLookup.get(fieldname); ScriptDocValues<?> scriptDocValues = leafDocLookup.get(fieldname);
double val = ((Number) scriptDocValues.get(0)).doubleValue(); return ((Number) scriptDocValues.get(0)).doubleValue() + 0.5;
if (val == 1) { });
val += 0.0000001;
scripts.put(TERM_FILTERING, vars -> {
LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc");
int term = (Integer) vars.get("term");
ScriptDocValues<?> termDocValues = leafDocLookup.get("term");
int currentTerm = ((Number) termDocValues.get(0)).intValue();
if (currentTerm == term) {
return ((Number) leafDocLookup.get("field").get(0)).doubleValue();
} }
return val + 0.5; return null;
}); });
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
@@ -134,6 +147,26 @@ public class TTestAggregatorTests extends AggregatorTestCase {
ex.getMessage()); ex.getMessage());
} }
public void testSameFieldAndNoFilters() {
TTestType tTestType = randomFrom(TTestType.values());
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType.setName("field");
TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder("t_test")
.a(new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setMissing(100).build())
.b(new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setMissing(100).build())
.testType(tTestType);
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
iw.addDocument(singleton(new SortedNumericDocValuesField("field", 102)));
iw.addDocument(singleton(new SortedNumericDocValuesField("field", 99)));
}, tTest -> fail("Should have thrown exception"), fieldType)
);
assertEquals(
"The same field [field] is used for both population but no filters are specified.",
ex.getMessage());
}
public void testMultipleUnpairedValues() throws IOException { public void testMultipleUnpairedValues() throws IOException {
TTestType tTestType = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC); TTestType tTestType = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC);
testCase(new MatchAllDocsQuery(), tTestType, iw -> { testCase(new MatchAllDocsQuery(), tTestType, iw -> {
@@ -143,6 +176,15 @@ public class TTestAggregatorTests extends AggregatorTestCase {
}, tTest -> assertEquals(tTestType == TTestType.HETEROSCEDASTIC ? 0.0607303911 : 0.01718374671, tTest.getValue(), 0.000001)); }, tTest -> assertEquals(tTestType == TTestType.HETEROSCEDASTIC ? 0.0607303911 : 0.01718374671, tTest.getValue(), 0.000001));
} }
public void testUnpairedValuesWithFilters() throws IOException {
TTestType tTestType = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC);
testCase(new MatchAllDocsQuery(), tTestType, iw -> {
iw.addDocument(asList(new SortedNumericDocValuesField("a", 102), new SortedNumericDocValuesField("a", 103),
new SortedNumericDocValuesField("b", 89)));
iw.addDocument(asList(new SortedNumericDocValuesField("a", 99), new SortedNumericDocValuesField("b", 93)));
}, tTest -> assertEquals(tTestType == TTestType.HETEROSCEDASTIC ? 0.0607303911 : 0.01718374671, tTest.getValue(), 0.000001));
}
public void testMissingValues() throws IOException { public void testMissingValues() throws IOException {
TTestType tTestType = randomFrom(TTestType.values()); TTestType tTestType = randomFrom(TTestType.values());
testCase(new MatchAllDocsQuery(), tTestType, iw -> { testCase(new MatchAllDocsQuery(), tTestType, iw -> {
@@ -426,12 +468,12 @@ public class TTestAggregatorTests extends AggregatorTestCase {
a(fieldInA ? a : b).b(fieldInA ? b : a).testType(tTestType); a(fieldInA ? a : b).b(fieldInA ? b : a).testType(tTestType);
testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> { testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
iw.addDocument(singleton(new NumericDocValuesField("field", 1))); iw.addDocument(singleton(new NumericDocValuesField("field", 1)));
iw.addDocument(singleton(new NumericDocValuesField("field", 2))); iw.addDocument(singleton(new NumericDocValuesField("field", 2)));
iw.addDocument(singleton(new NumericDocValuesField("field", 3))); iw.addDocument(singleton(new NumericDocValuesField("field", 3)));
}, (Consumer<InternalTTest>) tTest -> { }, (Consumer<InternalTTest>) tTest -> {
assertEquals(tTestType == TTestType.PAIRED ? 0 : 0.5733922538, tTest.getValue(), 0.000001); assertEquals(tTestType == TTestType.PAIRED ? 0 : 0.5733922538, tTest.getValue(), 0.000001);
}, fieldType); }, fieldType);
} }
public void testPaired() throws IOException { public void testPaired() throws IOException {
@@ -484,7 +526,6 @@ }, fieldType1, fieldType2); }, fieldType1, fieldType2);
}, fieldType1, fieldType2); }, fieldType1, fieldType2);
} }
public void testHeteroscedastic() throws IOException { public void testHeteroscedastic() throws IOException {
MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType1.setName("a"); fieldType1.setName("a");
@@ -512,6 +553,93 @@
}, fieldType1, fieldType2); }, fieldType1, fieldType2);
} }
public void testFiltered() throws IOException {
TTestType tTestType = randomFrom(TTestType.values());
MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType1.setName("a");
MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType2.setName("b");
TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder("t_test")
.a(new MultiValuesSourceFieldConfig.Builder().setFieldName("a").setFilter(QueryBuilders.termQuery("b", 1)).build())
.b(new MultiValuesSourceFieldConfig.Builder().setFieldName("a").setFilter(QueryBuilders.termQuery("b", 2)).build())
.testType(tTestType);
int tails = randomIntBetween(1, 2);
if (tails == 1 || randomBoolean()) {
aggregationBuilder.tails(tails);
}
CheckedConsumer<RandomIndexWriter, IOException> buildIndex = iw -> {
iw.addDocument(asList(new NumericDocValuesField("a", 102), new IntPoint("b", 1)));
iw.addDocument(asList(new NumericDocValuesField("a", 99), new IntPoint("b", 1)));
iw.addDocument(asList(new NumericDocValuesField("a", 111), new IntPoint("b", 1)));
iw.addDocument(asList(new NumericDocValuesField("a", 97), new IntPoint("b", 1)));
iw.addDocument(asList(new NumericDocValuesField("a", 101), new IntPoint("b", 1)));
iw.addDocument(asList(new NumericDocValuesField("a", 99), new IntPoint("b", 1)));
iw.addDocument(asList(new NumericDocValuesField("a", 89), new IntPoint("b", 2)));
iw.addDocument(asList(new NumericDocValuesField("a", 93), new IntPoint("b", 2)));
iw.addDocument(asList(new NumericDocValuesField("a", 72), new IntPoint("b", 2)));
iw.addDocument(asList(new NumericDocValuesField("a", 98), new IntPoint("b", 2)));
iw.addDocument(asList(new NumericDocValuesField("a", 102), new IntPoint("b", 2)));
iw.addDocument(asList(new NumericDocValuesField("a", 98), new IntPoint("b", 2)));
iw.addDocument(asList(new NumericDocValuesField("a", 189), new IntPoint("b", 3)));
iw.addDocument(asList(new NumericDocValuesField("a", 193), new IntPoint("b", 3)));
iw.addDocument(asList(new NumericDocValuesField("a", 172), new IntPoint("b", 3)));
iw.addDocument(asList(new NumericDocValuesField("a", 198), new IntPoint("b", 3)));
iw.addDocument(asList(new NumericDocValuesField("a", 1102), new IntPoint("b", 3)));
iw.addDocument(asList(new NumericDocValuesField("a", 198), new IntPoint("b", 3)));
};
if (tTestType == TTestType.PAIRED) {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, tTest -> fail("Should have thrown exception"),
fieldType1, fieldType2)
);
assertEquals("Paired t-test doesn't support filters", ex.getMessage());
} else {
testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, (Consumer<InternalTTest>) ttest -> {
if (tTestType == TTestType.HOMOSCEDASTIC) {
assertEquals(0.03928288693 * tails, ttest.getValue(), 0.00001);
} else {
assertEquals(0.04538666214 * tails, ttest.getValue(), 0.00001);
}
}, fieldType1, fieldType2);
}
}
public void testFilterByFilterOrScript() throws IOException {
boolean fieldInA = randomBoolean();
TTestType tTestType = randomFrom(TTestType.HOMOSCEDASTIC, TTestType.HETEROSCEDASTIC);
MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType1.setName("field");
MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType2.setName("term");
boolean filterTermOne = randomBoolean();
MultiValuesSourceFieldConfig.Builder a = new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setFilter(
QueryBuilders.termQuery("term", filterTermOne? 1 : 2)
);
MultiValuesSourceFieldConfig.Builder b = new MultiValuesSourceFieldConfig.Builder().setScript(
new Script(ScriptType.INLINE, MockScriptEngine.NAME, TERM_FILTERING, Collections.singletonMap("term", filterTermOne? 2 : 1))
);
TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder("t_test").
a(fieldInA ? a.build() : b.build()).b(fieldInA ? b.build() : a.build()).testType(tTestType);
testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
iw.addDocument(asList(new NumericDocValuesField("field", 1), new IntPoint("term", 1), new NumericDocValuesField("term", 1)));
iw.addDocument(asList(new NumericDocValuesField("field", 2), new IntPoint("term", 1), new NumericDocValuesField("term", 1)));
iw.addDocument(asList(new NumericDocValuesField("field", 3), new IntPoint("term", 1), new NumericDocValuesField("term", 1)));
iw.addDocument(asList(new NumericDocValuesField("field", 4), new IntPoint("term", 2), new NumericDocValuesField("term", 2)));
iw.addDocument(asList(new NumericDocValuesField("field", 5), new IntPoint("term", 2), new NumericDocValuesField("term", 2)));
iw.addDocument(asList(new NumericDocValuesField("field", 6), new IntPoint("term", 2), new NumericDocValuesField("term", 2)));
}, (Consumer<InternalTTest>) tTest -> {
assertEquals(0.02131164113, tTest.getValue(), 0.000001);
}, fieldType1, fieldType2);
}
private void testCase(Query query, TTestType type, private void testCase(Query query, TTestType type,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex, CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalTTest> verify) throws IOException { Consumer<InternalTTest> verify) throws IOException {
@@ -0,0 +1,72 @@
---
setup:
- do:
bulk:
index: test
refresh: true
body:
- '{"index": {}}'
- '{"v1": 15.2, "v2": 15.9, "str": "a"}'
- '{"index": {}}'
- '{"v1": 15.3, "v2": 15.9, "str": "a"}'
- '{"index": {}}'
- '{"v1": 16.0, "v2": 15.2, "str": "b"}'
- '{"index": {}}'
- '{"v1": 15.1, "v2": 15.5, "str": "b"}'
---
"heteroscedastic t-test":
- do:
search:
size: 0
index: "test"
body:
aggs:
ttest:
t_test:
a:
field: v1
b:
field: v2
- match: { aggregations.ttest.value: 0.43066659210472646 }
---
"paired t-test":
- do:
search:
size: 0
index: "test"
body:
aggs:
ttest:
t_test:
a:
field: v1
b:
field: v2
type: paired
- match: { aggregations.ttest.value: 0.5632529432617406 }
---
"homoscedastic t-test with filters":
- do:
search:
size: 0
index: "test"
body:
aggs:
ttest:
t_test:
a:
field: v1
filter:
term:
str.keyword: a
b:
field: v1
filter:
term:
str.keyword: b
type: homoscedastic
- match: { aggregations.ttest.value: 0.5757355806262943 }
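The expected value in the last test can be reproduced independently. The filters split `v1` into {15.2, 15.3} (`str == a`) and {16.0, 15.1} (`str == b`); with no explicit mapping, dynamic mapping indexes these JSON decimals as 32-bit floats, and for df = 2 the two-sided Student-t p-value has the closed form p = 1 - |t| / sqrt(t^2 + 2). A plain-Java sketch of that check (not Elasticsearch code):

```java
// Independent check of the "homoscedastic t-test with filters" expectation.
// Population A is v1 where str == "a"; population B is v1 where str == "b".
// The (float) casts mirror dynamic mapping, which indexes the JSON decimals
// as 32-bit floats; without them the result differs in the 7th decimal.
public class HomoscedasticCheck {
    static double mean(double[] xs) {
        double sum = 0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    static double sampleVariance(double[] xs) {
        double m = mean(xs), ss = 0;
        for (double x : xs) ss += (x - m) * (x - m);
        return ss / (xs.length - 1); // unbiased (n - 1) estimator
    }

    static double twoSidedP(double[] a, double[] b) {
        int df = a.length + b.length - 2; // here df == 2
        double pooled = ((a.length - 1) * sampleVariance(a)
            + (b.length - 1) * sampleVariance(b)) / df;
        double t = (mean(a) - mean(b))
            / Math.sqrt(pooled * (1.0 / a.length + 1.0 / b.length));
        return 1 - Math.abs(t) / Math.sqrt(t * t + 2); // exact two-sided p for df == 2
    }

    public static void main(String[] args) {
        double[] a = {(float) 15.2, (float) 15.3}; // v1 where str == "a"
        double[] b = {(float) 16.0, (float) 15.1}; // v1 where str == "b"
        System.out.println(twoSidedP(a, b)); // ~0.57573558
    }
}
```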