mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-27 18:38:41 +00:00
Added unit tests for MatrixStatsAggregator
This commit is contained in:
parent
b2ccb6b0a8
commit
34093735e3
@ -139,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
|
||||
return results.getCorrelation(fieldX, fieldY);
|
||||
}
|
||||
|
||||
RunningStats getStats() {
|
||||
return stats;
|
||||
}
|
||||
|
||||
MatrixStatsResults getResults() {
|
||||
return results;
|
||||
}
|
||||
|
@ -41,14 +41,14 @@ import java.util.Map;
|
||||
/**
|
||||
* Metric Aggregation for computing the pearson product correlation coefficient between multiple fields
|
||||
**/
|
||||
public class MatrixStatsAggregator extends MetricsAggregator {
|
||||
final class MatrixStatsAggregator extends MetricsAggregator {
|
||||
/** Multiple ValuesSource with field names */
|
||||
final NumericMultiValuesSource valuesSources;
|
||||
private final NumericMultiValuesSource valuesSources;
|
||||
|
||||
/** array of descriptive stats, per shard, needed to compute the correlation */
|
||||
ObjectArray<RunningStats> stats;
|
||||
|
||||
public MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
|
||||
MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
|
||||
Aggregator parent, MultiValueMode multiValueMode, List<PipelineAggregator> pipelineAggregators,
|
||||
Map<String,Object> metaData) throws IOException {
|
||||
super(name, context, parent, pipelineAggregators, metaData);
|
||||
|
@ -32,12 +32,12 @@ import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class MatrixStatsAggregatorFactory
|
||||
final class MatrixStatsAggregatorFactory
|
||||
extends MultiValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {
|
||||
|
||||
private final MultiValueMode multiValueMode;
|
||||
|
||||
public MatrixStatsAggregatorFactory(String name,
|
||||
MatrixStatsAggregatorFactory(String name,
|
||||
Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, MultiValueMode multiValueMode,
|
||||
SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
|
||||
Map<String, Object> metaData) throws IOException {
|
||||
|
@ -0,0 +1,96 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.search.aggregations.matrix.stats;
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.SortedNumericDocValuesField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.NumericUtils;
|
||||
import org.elasticsearch.index.mapper.MappedFieldType;
|
||||
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||
import org.elasticsearch.search.aggregations.AggregatorTestCase;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
public class MatrixStatsAggregatorTests extends AggregatorTestCase {
|
||||
|
||||
public void testNoData() throws Exception {
|
||||
MappedFieldType ft =
|
||||
new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
|
||||
ft.setName("field");
|
||||
|
||||
try (Directory directory = newDirectory();
|
||||
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
|
||||
if (randomBoolean()) {
|
||||
indexWriter.addDocument(Collections.singleton(new StringField("another_field", "value", Field.Store.NO)));
|
||||
}
|
||||
try (IndexReader reader = indexWriter.getReader()) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg")
|
||||
.fields(Collections.singletonList("field"));
|
||||
InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ft);
|
||||
assertNull(stats.getStats());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testTwoFields() throws Exception {
|
||||
String fieldA = "a";
|
||||
MappedFieldType ftA = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
|
||||
ftA.setName(fieldA);
|
||||
String fieldB = "b";
|
||||
MappedFieldType ftB = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
|
||||
ftB.setName(fieldB);
|
||||
|
||||
try (Directory directory = newDirectory();
|
||||
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
|
||||
|
||||
int numDocs = scaledRandomIntBetween(8192, 16384);
|
||||
Double[] fieldAValues = new Double[numDocs];
|
||||
Double[] fieldBValues = new Double[numDocs];
|
||||
for (int docId = 0; docId < numDocs; docId++) {
|
||||
Document document = new Document();
|
||||
fieldAValues[docId] = randomDouble();
|
||||
document.add(new SortedNumericDocValuesField(fieldA, NumericUtils.doubleToSortableLong(fieldAValues[docId])));
|
||||
|
||||
fieldBValues[docId] = randomDouble();
|
||||
document.add(new SortedNumericDocValuesField(fieldB, NumericUtils.doubleToSortableLong(fieldBValues[docId])));
|
||||
indexWriter.addDocument(document);
|
||||
}
|
||||
|
||||
MultiPassStats multiPassStats = new MultiPassStats(fieldA, fieldB);
|
||||
multiPassStats.computeStats(Arrays.asList(fieldAValues), Arrays.asList(fieldBValues));
|
||||
try (IndexReader reader = indexWriter.getReader()) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg")
|
||||
.fields(Arrays.asList(fieldA, fieldB));
|
||||
InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ftA, ftB);
|
||||
multiPassStats.assertNearlyEqual(new MatrixStatsResults(stats.getStats()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -110,6 +110,9 @@ public abstract class AggregatorTestCase extends ESTestCase {
|
||||
|
||||
QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);
|
||||
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
|
||||
for (MappedFieldType fieldType : fieldTypes) {
|
||||
when(searchContext.smartNameFieldType(fieldType.name())).thenReturn(fieldType);
|
||||
}
|
||||
|
||||
return aggregationBuilder.build(searchContext, null);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user