Added unit tests for MatrixStatsAggregator

This commit is contained in:
Martijn van Groningen 2017-05-23 09:14:53 +02:00
parent b2ccb6b0a8
commit 34093735e3
No known key found for this signature in database
GPG Key ID: AB236F4FCF2AF12A
5 changed files with 108 additions and 5 deletions

View File

@ -139,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
return results.getCorrelation(fieldX, fieldY);
}
/**
 * Package-private accessor for the {@code stats} field of this aggregation.
 *
 * @return the value currently held in {@code stats}; callers in this file check it for {@code null}
 */
RunningStats getStats() {
    return this.stats;
}
/**
 * Package-private accessor for the {@code results} field of this aggregation.
 *
 * @return the value currently held in {@code results}
 */
MatrixStatsResults getResults() {
    return this.results;
}

View File

@ -41,14 +41,14 @@ import java.util.Map;
/**
* Metric Aggregation for computing the Pearson product-moment correlation coefficient between multiple fields
**/
public class MatrixStatsAggregator extends MetricsAggregator {
final class MatrixStatsAggregator extends MetricsAggregator {
/** Multiple ValuesSource with field names */
final NumericMultiValuesSource valuesSources;
private final NumericMultiValuesSource valuesSources;
/** array of descriptive stats, per shard, needed to compute the correlation */
ObjectArray<RunningStats> stats;
public MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
Aggregator parent, MultiValueMode multiValueMode, List<PipelineAggregator> pipelineAggregators,
Map<String,Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);

View File

@ -32,12 +32,12 @@ import java.io.IOException;
import java.util.List;
import java.util.Map;
public class MatrixStatsAggregatorFactory
final class MatrixStatsAggregatorFactory
extends MultiValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {
private final MultiValueMode multiValueMode;
public MatrixStatsAggregatorFactory(String name,
MatrixStatsAggregatorFactory(String name,
Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, MultiValueMode multiValueMode,
SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {

View File

@ -0,0 +1,96 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.matrix.stats;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import java.util.Arrays;
import java.util.Collections;
/**
 * Unit tests that run a {@link MatrixStatsAggregationBuilder} against a small in-memory Lucene index
 * and inspect the resulting {@link InternalMatrixStats}.
 */
public class MatrixStatsAggregatorTests extends AggregatorTestCase {

    /** Name given to the aggregation under test in every test case. */
    private static final String AGG_NAME = "my_agg";

    /**
     * Creates a {@code DOUBLE} number field type and assigns it the given name.
     */
    private static MappedFieldType newDoubleFieldType(String name) {
        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
        fieldType.setName(name);
        return fieldType;
    }

    /**
     * Runs the aggregation over an index that never contains a value for the aggregated field
     * (the index is either empty or holds a single document with an unrelated string field)
     * and expects the aggregation to expose no running stats.
     */
    public void testNoData() throws Exception {
        final String aggregatedField = "field";
        MappedFieldType fieldType = newDoubleFieldType(aggregatedField);

        try (Directory dir = newDirectory();
             RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
            // Optionally index one document that only carries a field the aggregation does not look at.
            if (randomBoolean()) {
                StringField unrelated = new StringField("another_field", "value", Field.Store.NO);
                writer.addDocument(Collections.singleton(unrelated));
            }

            try (IndexReader indexReader = writer.getReader()) {
                IndexSearcher indexSearcher = new IndexSearcher(indexReader);
                MatrixStatsAggregationBuilder builder = new MatrixStatsAggregationBuilder(AGG_NAME)
                    .fields(Collections.singletonList(aggregatedField));
                InternalMatrixStats result = search(indexSearcher, new MatchAllDocsQuery(), builder, fieldType);
                assertNull(result.getStats());
            }
        }
    }

    /**
     * Indexes a random number of documents, each with one random double for field {@code a} and one
     * for field {@code b}, then checks that the aggregation output is nearly equal to the values
     * computed by {@link MultiPassStats} over the same data.
     */
    public void testTwoFields() throws Exception {
        final String fieldA = "a";
        final String fieldB = "b";
        MappedFieldType ftA = newDoubleFieldType(fieldA);
        MappedFieldType ftB = newDoubleFieldType(fieldB);

        try (Directory dir = newDirectory();
             RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
            final int docCount = scaledRandomIntBetween(8192, 16384);
            Double[] valuesA = new Double[docCount];
            Double[] valuesB = new Double[docCount];

            for (int i = 0; i < docCount; i++) {
                Document doc = new Document();

                // Draw and add the value for field "a" before the one for "b" (keeps random call order stable).
                double a = randomDouble();
                valuesA[i] = a;
                doc.add(new SortedNumericDocValuesField(fieldA, NumericUtils.doubleToSortableLong(a)));

                double b = randomDouble();
                valuesB[i] = b;
                doc.add(new SortedNumericDocValuesField(fieldB, NumericUtils.doubleToSortableLong(b)));

                writer.addDocument(doc);
            }

            // Expected statistics, computed directly from the generated values.
            MultiPassStats expected = new MultiPassStats(fieldA, fieldB);
            expected.computeStats(Arrays.asList(valuesA), Arrays.asList(valuesB));

            try (IndexReader indexReader = writer.getReader()) {
                IndexSearcher indexSearcher = new IndexSearcher(indexReader);
                MatrixStatsAggregationBuilder builder = new MatrixStatsAggregationBuilder(AGG_NAME)
                    .fields(Arrays.asList(fieldA, fieldB));
                InternalMatrixStats result = search(indexSearcher, new MatchAllDocsQuery(), builder, ftA, ftB);
                MatrixStatsResults actual = new MatrixStatsResults(result.getStats());
                expected.assertNearlyEqual(actual);
            }
        }
    }
}

View File

@ -110,6 +110,9 @@ public abstract class AggregatorTestCase extends ESTestCase {
QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
for (MappedFieldType fieldType : fieldTypes) {
when(searchContext.smartNameFieldType(fieldType.name())).thenReturn(fieldType);
}
return aggregationBuilder.build(searchContext, null);
}