Add tests to MedianAbsoluteDeviationAggregator (#54884) (#55282)

This commit is contained in:
Christos Soulios 2020-04-16 13:46:09 +03:00 committed by GitHub
parent 6a0eebf1d7
commit 2a56a3a1f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 172 additions and 63 deletions

View File

@ -32,22 +32,37 @@ import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.aggregations.support.ValuesSourceType;
import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.IntStream;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static org.elasticsearch.search.aggregations.metrics.MedianAbsoluteDeviationAggregatorTests.ExactMedianAbsoluteDeviation.calculateMAD;
import static org.elasticsearch.search.aggregations.metrics.MedianAbsoluteDeviationAggregatorTests.IsCloseToRelative.closeToRelative;
import static org.hamcrest.Matchers.equalTo;
@ -56,6 +71,11 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
private static final int SAMPLE_MIN = -1000000;
private static final int SAMPLE_MAX = 1000000;
public static final String FIELD_NAME = "number";
/** Script that returns the {@code _value} provided by the aggs framework, incremented by one. */
private static final String VALUE_SCRIPT = "_value";
private static final String SINGLE_SCRIPT = "single";
private static <T extends IndexableField> CheckedConsumer<RandomIndexWriter, IOException> randomSample(
int size,
@ -96,10 +116,10 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
final int size = randomIntBetween(100, 1000);
final List<Long> sample = new ArrayList<>(size);
testCase(
new DocValuesFieldExistsQuery("number"),
new DocValuesFieldExistsQuery(FIELD_NAME),
randomSample(size, point -> {
sample.add(point);
return singleton(new SortedNumericDocValuesField("number", point));
return singleton(new SortedNumericDocValuesField(FIELD_NAME, point));
}),
agg -> {
assertThat(agg.getMedianAbsoluteDeviation(), closeToRelative(calculateMAD(sample)));
@ -112,10 +132,10 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
final int size = randomIntBetween(100, 1000);
final List<Long> sample = new ArrayList<>(size);
testCase(
new DocValuesFieldExistsQuery("number"),
new DocValuesFieldExistsQuery(FIELD_NAME),
randomSample(size, point -> {
sample.add(point);
return singleton(new NumericDocValuesField("number", point));
return singleton(new NumericDocValuesField(FIELD_NAME, point));
}),
agg -> {
assertThat(agg.getMedianAbsoluteDeviation(), closeToRelative(calculateMAD(sample)));
@ -130,10 +150,10 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
final int[] sample = IntStream.rangeClosed(1, 1000).toArray();
final int[] filteredSample = Arrays.stream(sample).filter(point -> point >= lowerRange && point <= upperRange).toArray();
testCase(
IntPoint.newRangeQuery("number", lowerRange, upperRange),
IntPoint.newRangeQuery(FIELD_NAME, lowerRange, upperRange),
writer -> {
for (int point : sample) {
writer.addDocument(Arrays.asList(new IntPoint("number", point), new SortedNumericDocValuesField("number", point)));
writer.addDocument(Arrays.asList(new IntPoint(FIELD_NAME, point), new SortedNumericDocValuesField(FIELD_NAME, point)));
}
},
agg -> {
@ -145,10 +165,10 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
public void testQueryFiltersAll() throws IOException {
testCase(
IntPoint.newRangeQuery("number", -1, 0),
IntPoint.newRangeQuery(FIELD_NAME, -1, 0),
writer -> {
writer.addDocument(Arrays.asList(new IntPoint("number", 1), new SortedNumericDocValuesField("number", 1)));
writer.addDocument(Arrays.asList(new IntPoint("number", 2), new SortedNumericDocValuesField("number", 2)));
writer.addDocument(Arrays.asList(new IntPoint(FIELD_NAME, 1), new SortedNumericDocValuesField(FIELD_NAME, 1)));
writer.addDocument(Arrays.asList(new IntPoint(FIELD_NAME, 2), new SortedNumericDocValuesField(FIELD_NAME, 2)));
},
agg -> {
assertThat(agg.getMedianAbsoluteDeviation(), equalTo(Double.NaN));
@ -157,34 +177,110 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
);
}
private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalMedianAbsoluteDeviation> verify) throws IOException {
/**
 * Runs the aggregation on {@link #FIELD_NAME} while passing a {@code null} field type to
 * {@code testCase}, and checks that the result is {@code NaN} and is reported as having no value.
 */
public void testUnmapped() throws IOException {
    // Two documents that do carry doc values for the field; only the field type is absent.
    final CheckedConsumer<RandomIndexWriter, IOException> twoDocs = writer -> {
        writer.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, 7)));
        writer.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, 1)));
    };
    final Consumer<InternalMedianAbsoluteDeviation> expectNoValue = mad -> {
        assertEquals(Double.NaN, mad.getMedianAbsoluteDeviation(), 0);
        assertFalse(AggregationInspectionHelper.hasValue(mad));
    };

    final MedianAbsoluteDeviationAggregationBuilder unmappedBuilder =
        new MedianAbsoluteDeviationAggregationBuilder("foo").field(FIELD_NAME);

    testCase(unmappedBuilder, new DocValuesFieldExistsQuery(FIELD_NAME), twoDocs, expectNoValue, null);
}
/**
 * Runs the aggregation with a {@code missing} value of 1234 and a {@code null} field type over
 * documents that only contain {@code unrelatedField}, and checks that the result is 0 and is
 * reported as having a value.
 */
public void testUnmappedMissing() throws IOException {
    final CheckedConsumer<RandomIndexWriter, IOException> unrelatedDocs = writer -> {
        writer.addDocument(singleton(new NumericDocValuesField("unrelatedField", 7)));
        writer.addDocument(singleton(new NumericDocValuesField("unrelatedField", 8)));
        writer.addDocument(singleton(new NumericDocValuesField("unrelatedField", 9)));
    };
    final Consumer<InternalMedianAbsoluteDeviation> expectZero = mad -> {
        assertEquals(0, mad.getMedianAbsoluteDeviation(), 0);
        assertTrue(AggregationInspectionHelper.hasValue(mad));
    };

    final MedianAbsoluteDeviationAggregationBuilder missingBuilder =
        new MedianAbsoluteDeviationAggregationBuilder("foo").field(FIELD_NAME).missing(1234);

    testCase(missingBuilder, new MatchAllDocsQuery(), unrelatedDocs, expectZero, null);
}
/**
 * Runs the aggregation on {@link #FIELD_NAME} combined with the {@link #VALUE_SCRIPT} mock script
 * and compares the result against the exact MAD of the values that were indexed.
 */
public void testValueScript() throws IOException {
    final Script valueScript = new Script(ScriptType.INLINE, MockScriptEngine.NAME, VALUE_SCRIPT, Collections.emptyMap());
    final MedianAbsoluteDeviationAggregationBuilder scriptedBuilder =
        new MedianAbsoluteDeviationAggregationBuilder("foo").field(FIELD_NAME).script(valueScript);

    final MappedFieldType longFieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
    longFieldType.setName(FIELD_NAME);
    longFieldType.setHasDocValues(true);

    final int docCount = randomIntBetween(100, 1000);
    final List<Long> indexedValues = new ArrayList<>(docCount);

    // The mock script registered in getMockScriptService() adds 1 to every value; a constant
    // shift does not change the median absolute deviation, so the expectation is computed
    // from the unshifted values.
    final Consumer<InternalMedianAbsoluteDeviation> verifier = mad -> {
        assertThat(mad.getMedianAbsoluteDeviation(), closeToRelative(calculateMAD(indexedValues)));
        assertTrue(AggregationInspectionHelper.hasValue(mad));
    };

    testCase(
        scriptedBuilder,
        new MatchAllDocsQuery(),
        randomSample(docCount, value -> {
            indexedValues.add(value);
            return singleton(new SortedNumericDocValuesField(FIELD_NAME, value));
        }),
        verifier,
        longFieldType
    );
}
/**
 * Runs the aggregation with only the {@link #SINGLE_SCRIPT} mock script (no field set on the
 * builder) over ten documents and checks that the result is 0 and is reported as having a value.
 * <p>
 * The mock script registered in {@code getMockScriptService()} returns the constant 1, so the
 * asserted deviation of 0 corresponds to every document contributing the same value.
 */
public void testSingleScript() throws IOException {
    MedianAbsoluteDeviationAggregationBuilder aggregationBuilder = new MedianAbsoluteDeviationAggregationBuilder("foo")
        .script(new Script(ScriptType.INLINE, MockScriptEngine.NAME, SINGLE_SCRIPT, Collections.emptyMap()));

    MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
    fieldType.setName(FIELD_NAME);

    testCase(aggregationBuilder,
        new MatchAllDocsQuery(),
        iw -> {
            // Index the values 1..10; the assertion below does not depend on them.
            for (int i = 0; i < 10; i++) {
                iw.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, i + 1)));
            }
        },
        agg -> {
            assertEquals(0, agg.getMedianAbsoluteDeviation(), 0);
            assertTrue(AggregationInspectionHelper.hasValue(agg));
        }, fieldType);
}
/**
 * Convenience overload that runs {@code testCase} with a default setup: an aggregation named
 * {@code "mad"} on {@link #FIELD_NAME} with a randomly chosen compression, and a LONG number
 * field type named {@link #FIELD_NAME}.
 *
 * @param query      query used to select the documents to aggregate
 * @param buildIndex callback that writes the test documents
 * @param verify     callback that checks the resulting aggregation
 */
private void testCase(Query query,
                      CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
                      Consumer<InternalMedianAbsoluteDeviation> verify) throws IOException {
    final MappedFieldType longFieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
    longFieldType.setName(FIELD_NAME);

    final double compression = randomDoubleBetween(20, 1000, true);
    final MedianAbsoluteDeviationAggregationBuilder defaultBuilder =
        new MedianAbsoluteDeviationAggregationBuilder("mad").field(FIELD_NAME).compression(compression);

    testCase(defaultBuilder, query, buildIndex, verify, longFieldType);
}
private void testCase(MedianAbsoluteDeviationAggregationBuilder aggregationBuilder, Query query,
CheckedConsumer<RandomIndexWriter, IOException> indexer,
Consumer<InternalMedianAbsoluteDeviation> verify, MappedFieldType fieldType) throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
buildIndex.accept(indexWriter);
indexer.accept(indexWriter);
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
MedianAbsoluteDeviationAggregationBuilder builder = new MedianAbsoluteDeviationAggregationBuilder("mad")
.field("number")
.compression(randomDoubleBetween(20, 1000, true));
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
MedianAbsoluteDeviationAggregator aggregator = createAggregator(builder, indexSearcher, fieldType);
Aggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);
aggregator.preCollection();
indexSearcher.search(query, aggregator);
aggregator.postCollection();
verify.accept((InternalMedianAbsoluteDeviation) aggregator.buildAggregation(0L));
}
}
}
public static class IsCloseToRelative extends TypeSafeMatcher<Double> {
@ -271,6 +367,30 @@ public class MedianAbsoluteDeviationAggregatorTests extends AggregatorTestCase {
}
return median;
}
}
/** Declares {@link CoreValuesSourceType#NUMERIC} as the only values source type listed for this aggregation. */
@Override
protected List<ValuesSourceType> getSupportedValuesSourceTypes() {
    final ValuesSourceType numeric = CoreValuesSourceType.NUMERIC;
    return singletonList(numeric);
}
/**
 * Builds a median absolute deviation aggregation named {@code "foo"} on the given field.
 *
 * @param fieldType not used by this implementation
 * @param fieldName name of the field the aggregation is configured on
 */
@Override
protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) {
    final MedianAbsoluteDeviationAggregationBuilder typeTestBuilder =
        new MedianAbsoluteDeviationAggregationBuilder("foo").field(fieldName);
    return typeTestBuilder;
}
/**
 * Creates a {@link ScriptService} backed by a single {@link MockScriptEngine} that knows the two
 * scripts used by this test class: {@link #VALUE_SCRIPT}, which returns {@code _value + 1} as a
 * double, and {@link #SINGLE_SCRIPT}, which always returns 1.
 */
@Override
protected ScriptService getMockScriptService() {
    final Map<String, Function<Map<String, Object>, Object>> mockScripts = new HashMap<>();
    mockScripts.put(VALUE_SCRIPT, params -> {
        final Number value = (Number) params.get("_value");
        return value.doubleValue() + 1;
    });
    mockScripts.put(SINGLE_SCRIPT, params -> 1);

    final MockScriptEngine engine = new MockScriptEngine(MockScriptEngine.NAME, mockScripts, Collections.emptyMap());
    final Map<String, ScriptEngine> enginesByType = Collections.singletonMap(engine.getType(), engine);

    return new ScriptService(Settings.EMPTY, enginesByType, ScriptModule.CORE_CONTEXTS);
}
}

View File

@ -182,20 +182,7 @@ public class MedianAbsoluteDeviationIT extends AbstractNumericTestCase {
@Override
public void testUnmapped() throws Exception {
final SearchResponse response = client()
.prepareSearch("idx_unmapped")
.setQuery(matchAllQuery())
.addAggregation(
randomBuilder()
.field("value"))
.get();
assertHitCount(response, 0);
final MedianAbsoluteDeviation mad = response.getAggregations().get("mad");
assertThat(mad, notNullValue());
assertThat(mad.getName(), is("mad"));
assertThat(mad.getMedianAbsoluteDeviation(), is(Double.NaN));
// Test moved to MedianAbsoluteDeviationAggregatorTests.testUnmapped()
}
@Override

View File

@ -219,9 +219,9 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 7)));
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 8)));
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 9)));
}, card -> {
assertEquals(3, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, null);
}
@ -233,9 +233,9 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 7)));
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 8)));
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 9)));
}, card -> {
assertEquals(3, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, null);
}
@ -247,9 +247,9 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 7)));
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 8)));
iw.addDocument(singleton(new NumericDocValuesField("unrelatedField", 9)));
}, card -> {
assertEquals(3, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, null);
}
@ -288,15 +288,15 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
iw.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, 7)));
iw.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, 8)));
iw.addDocument(singleton(new NumericDocValuesField(FIELD_NAME, 9)));
}, card -> {
assertEquals(3, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, fieldType);
}
public void testSingleScriptNumber() throws IOException {
ValueCountAggregationBuilder aggregationBuilder = new ValueCountAggregationBuilder("name", null)
.field(FIELD_NAME);
.script(new Script(ScriptType.INLINE, MockScriptEngine.NAME, SINGLE_SCRIPT, Collections.emptyMap()));
MappedFieldType fieldType = createMappedFieldType(ValueType.NUMERIC);
fieldType.setName(FIELD_NAME);
@ -317,10 +317,11 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
doc.add(new SortedNumericDocValuesField(FIELD_NAME, 1));
doc.add(new SortedNumericDocValuesField(FIELD_NAME, 1));
iw.addDocument(doc);
}, card -> {
// note: this is 6, even though the script returns a single value. ValueCount does not de-dedupe
assertEquals(6, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
// Note: The field values are not taken into account. The script is only called
// once per document, so we expect a count of 3.
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, fieldType);
}
@ -337,15 +338,15 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
iw.addDocument(singleton(new SortedDocValuesField(FIELD_NAME, new BytesRef("1"))));
iw.addDocument(singleton(new SortedDocValuesField(FIELD_NAME, new BytesRef("2"))));
iw.addDocument(singleton(new SortedDocValuesField(FIELD_NAME, new BytesRef("3"))));
}, card -> {
assertEquals(3, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, fieldType);
}
public void testSingleScriptString() throws IOException {
ValueCountAggregationBuilder aggregationBuilder = new ValueCountAggregationBuilder("name", null)
.field(FIELD_NAME);
.script(new Script(ScriptType.INLINE, MockScriptEngine.NAME, SINGLE_SCRIPT, Collections.emptyMap()));
MappedFieldType fieldType = createMappedFieldType(ValueType.STRING);
fieldType.setName(FIELD_NAME);
@ -367,10 +368,11 @@ public class ValueCountAggregatorTests extends AggregatorTestCase {
doc.add(new SortedSetDocValuesField(FIELD_NAME, new BytesRef("5")));
doc.add(new SortedSetDocValuesField(FIELD_NAME, new BytesRef("6")));
iw.addDocument(doc);
}, card -> {
// note: this is 6, even though the script returns a single value. ValueCount does not de-dedupe
assertEquals(6, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, valueCount -> {
// Note: The field values are not taken into account. The script is only called
// once per document, so we expect a count of 3.
assertEquals(3, valueCount.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(valueCount));
}, fieldType);
}