Added unit tests for InternalMatrixStats.

Also moved InternalAggregationTestCase to the test-framework module in order to make it usable from modules other than core.

Relates to #22278
This commit is contained in:
Martijn van Groningen 2017-05-09 15:04:28 +02:00
parent b24326271e
commit 51c74ce547
No known key found for this signature in database
GPG Key ID: AB236F4FCF2AF12A
32 changed files with 345 additions and 161 deletions

View File

@ -20,11 +20,11 @@
package org.elasticsearch.search.aggregations.bucket;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.metrics.max.InternalMax;
import org.elasticsearch.search.aggregations.metrics.min.InternalMin;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.List;

View File

@ -21,9 +21,9 @@ package org.elasticsearch.search.aggregations.bucket.geogrid;
import org.apache.lucene.index.IndexWriter;
import org.elasticsearch.common.geo.GeoHashUtils;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.HashMap;

View File

@ -19,13 +19,11 @@
package org.elasticsearch.search.aggregations.bucket.histogram;
import org.apache.lucene.util.TestUtil;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.joda.time.DateTime;
import java.util.ArrayList;

View File

@ -22,9 +22,9 @@ package org.elasticsearch.search.aggregations.bucket.histogram;
import org.apache.lucene.util.TestUtil;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.List;

View File

@ -22,9 +22,9 @@ import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.junit.Before;
import java.util.ArrayList;

View File

@ -19,8 +19,8 @@
package org.elasticsearch.search.aggregations.bucket.significant;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.Arrays;
import java.util.HashMap;

View File

@ -19,14 +19,14 @@
package org.elasticsearch.search.aggregations.bucket.terms;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

View File

@ -21,9 +21,9 @@ package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.stats.extended.InternalExtendedStats;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.junit.Before;
import java.util.Collections;

View File

@ -20,9 +20,9 @@
package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.max.InternalMax;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -20,9 +20,9 @@ package org.elasticsearch.search.aggregations.metrics;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.stats.InternalStats;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.Collections;
import java.util.List;

View File

@ -20,9 +20,8 @@
package org.elasticsearch.search.aggregations.metrics.avg;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -24,8 +24,8 @@ import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.junit.After;
import org.junit.Before;

View File

@ -20,8 +20,8 @@
package org.elasticsearch.search.aggregations.metrics.geobounds;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.Collections;
import java.util.List;

View File

@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.metrics.geocentroid;
import org.apache.lucene.geo.GeoEncodingUtils;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.test.geo.RandomGeoGenerator;
import java.util.Collections;

View File

@ -20,9 +20,8 @@
package org.elasticsearch.search.aggregations.metrics.min;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.metrics.percentiles;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.junit.Before;
import java.util.List;

View File

@ -22,9 +22,8 @@ package org.elasticsearch.search.aggregations.metrics.percentiles.hdr;
import org.HdrHistogram.DoubleHistogram;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.percentiles.hdr.InternalHDRPercentileRanks;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.metrics.percentiles.tdigest;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -30,8 +30,8 @@ import org.elasticsearch.script.ScriptEngineRegistry;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptSettings;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.io.IOException;
import java.util.Collections;

View File

@ -20,8 +20,8 @@ package org.elasticsearch.search.aggregations.metrics.sum;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -30,11 +30,11 @@ import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.Arrays;

View File

@ -20,8 +20,8 @@
package org.elasticsearch.search.aggregations.metrics.valuecount;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.List;
import java.util.Map;

View File

@ -21,7 +21,7 @@ package org.elasticsearch.search.aggregations.pipeline;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.Collections;
import java.util.List;

View File

@ -21,9 +21,9 @@ package org.elasticsearch.search.aggregations.pipeline.bucketmetrics.percentile;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.percentiles.Percentile;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.Collections;
import java.util.Iterator;

View File

@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.pipeline.derivative;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.Collections;
import java.util.List;

View File

@ -28,6 +28,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static java.util.Collections.emptyMap;
@ -41,7 +42,7 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
private final MatrixStatsResults results;
/** per shard ctor */
protected InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results,
InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
super(name, pipelineAggregators, metaData);
assert count >= 0;
@ -138,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
return results.getCorrelation(fieldX, fieldY);
}
MatrixStatsResults getResults() {
return results;
}
static class Fields {
public static final String FIELDS = "fields";
public static final String NAME = "name";
@ -238,4 +243,16 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
return new InternalMatrixStats(name, results.getDocCount(), runningStats, results, pipelineAggregators(), getMetaData());
}
@Override
protected int doHashCode() {
    // Hash exactly the fields compared in doEquals (stats, results) so the
    // equals/hashCode contract holds: equal instances produce equal hashes.
    return Objects.hash(stats, results);
}
@Override
protected boolean doEquals(Object obj) {
    // NOTE(review): no null/type check here — presumably the base class equals()
    // verifies obj is a non-null instance of the same class before delegating;
    // confirm against InternalAggregation#equals.
    InternalMatrixStats other = (InternalMatrixStats) obj;
    // Two aggregations are equal when both the per-shard running stats and the
    // reduced results match.
    return Objects.equals(this.stats, other.stats) &&
        Objects.equals(this.results, other.results);
}
}

View File

@ -27,6 +27,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* Descriptive stats gathered per shard. Coordinating node computes final pearson product coefficient
@ -228,4 +229,18 @@ class MatrixStatsResults implements Writeable {
correlation.put(rowName, corRow);
}
}
/** Two results are equal when both the per-field results and the correlation matrix match. */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    MatrixStatsResults other = (MatrixStatsResults) o;
    return Objects.equals(results, other.results)
        && Objects.equals(correlation, other.correlation);
}
@Override
public int hashCode() {
    // Consistent with equals: hashes the same two fields (results, correlation).
    return Objects.hash(results, correlation);
}
}

View File

@ -28,6 +28,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* Descriptive stats gathered per shard. Coordinating node computes final correlation and covariance stats
@ -53,11 +54,11 @@ public class RunningStats implements Writeable, Cloneable {
/** covariance values */
protected HashMap<String, HashMap<String, Double>> covariances;
public RunningStats() {
RunningStats() {
init();
}
public RunningStats(final String[] fieldNames, final double[] fieldVals) {
RunningStats(final String[] fieldNames, final double[] fieldVals) {
if (fieldVals != null && fieldVals.length > 0) {
init();
this.add(fieldNames, fieldVals);
@ -309,4 +310,24 @@ public class RunningStats implements Writeable, Cloneable {
throw new ElasticsearchException("Error trying to create a copy of RunningStats");
}
}
/** Field-by-field equality over the document count and every accumulated statistic. */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    RunningStats other = (RunningStats) o;
    if (docCount != other.docCount) {
        return false;
    }
    return Objects.equals(fieldSum, other.fieldSum)
        && Objects.equals(counts, other.counts)
        && Objects.equals(means, other.means)
        && Objects.equals(variances, other.variances)
        && Objects.equals(skewness, other.skewness)
        && Objects.equals(kurtosis, other.kurtosis)
        && Objects.equals(covariances, other.covariances);
}
@Override
public int hashCode() {
    // Consistent with equals: hashes the doc count plus every accumulated statistic.
    return Objects.hash(docCount, fieldSum, counts, means, variances, skewness, kurtosis, covariances);
}
}

View File

@ -22,15 +22,12 @@ import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.util.ArrayList;
import java.util.HashMap;
import static org.hamcrest.Matchers.equalTo;
public abstract class BaseMatrixStatsTestCase extends ESTestCase {
protected final int numObs = atLeast(10000);
protected final ArrayList<Double> fieldA = new ArrayList<>(numObs);
protected final ArrayList<Double> fieldB = new ArrayList<>(numObs);
protected final MultiPassStats actualStats = new MultiPassStats();
protected final MultiPassStats actualStats = new MultiPassStats(fieldAKey, fieldBKey);
protected static final String fieldAKey = "fieldA";
protected static final String fieldBKey = "fieldB";
@ -47,123 +44,4 @@ public abstract class BaseMatrixStatsTestCase extends ESTestCase {
actualStats.computeStats(fieldA, fieldB);
}
static class MultiPassStats {
long count;
HashMap<String, Double> means = new HashMap<>();
HashMap<String, Double> variances = new HashMap<>();
HashMap<String, Double> skewness = new HashMap<>();
HashMap<String, Double> kurtosis = new HashMap<>();
HashMap<String, HashMap<String, Double>> covariances = new HashMap<>();
HashMap<String, HashMap<String, Double>> correlations = new HashMap<>();
@SuppressWarnings("unchecked")
void computeStats(final ArrayList<Double> fieldA, final ArrayList<Double> fieldB) {
// set count
count = fieldA.size();
double meanA = 0d;
double meanB = 0d;
// compute mean
for (int n = 0; n < count; ++n) {
// fieldA
meanA += fieldA.get(n);
meanB += fieldB.get(n);
}
means.put(fieldAKey, meanA/count);
means.put(fieldBKey, meanB/count);
// compute variance, skewness, and kurtosis
double dA;
double dB;
double skewA = 0d;
double skewB = 0d;
double kurtA = 0d;
double kurtB = 0d;
double varA = 0d;
double varB = 0d;
double cVar = 0d;
for (int n = 0; n < count; ++n) {
dA = fieldA.get(n) - means.get(fieldAKey);
varA += dA * dA;
skewA += dA * dA * dA;
kurtA += dA * dA * dA * dA;
dB = fieldB.get(n) - means.get(fieldBKey);
varB += dB * dB;
skewB += dB * dB * dB;
kurtB += dB * dB * dB * dB;
cVar += dA * dB;
}
variances.put(fieldAKey, varA / (count - 1));
final double stdA = Math.sqrt(variances.get(fieldAKey));
variances.put(fieldBKey, varB / (count - 1));
final double stdB = Math.sqrt(variances.get(fieldBKey));
skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA));
skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB));
kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey)));
kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey)));
// compute covariance
final HashMap<String, Double> fieldACovar = new HashMap<>(2);
fieldACovar.put(fieldAKey, 1d);
cVar /= count - 1;
fieldACovar.put(fieldBKey, cVar);
covariances.put(fieldAKey, fieldACovar);
final HashMap<String, Double> fieldBCovar = new HashMap<>(2);
fieldBCovar.put(fieldAKey, cVar);
fieldBCovar.put(fieldBKey, 1d);
covariances.put(fieldBKey, fieldBCovar);
// compute correlation
final HashMap<String, Double> fieldACorr = new HashMap<>();
fieldACorr.put(fieldAKey, 1d);
double corr = covariances.get(fieldAKey).get(fieldBKey);
corr /= stdA * stdB;
fieldACorr.put(fieldBKey, corr);
correlations.put(fieldAKey, fieldACorr);
final HashMap<String, Double> fieldBCorr = new HashMap<>();
fieldBCorr.put(fieldAKey, corr);
fieldBCorr.put(fieldBKey, 1d);
correlations.put(fieldBKey, fieldBCorr);
}
public void assertNearlyEqual(MatrixStatsResults stats) {
assertThat(count, equalTo(stats.getDocCount()));
assertThat(count, equalTo(stats.getFieldCount(fieldAKey)));
assertThat(count, equalTo(stats.getFieldCount(fieldBKey)));
// means
assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7));
assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7));
// variances
assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7));
assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7));
// skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4));
assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4));
// kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4));
assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4));
// covariances
assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey), stats.getCovariance(fieldAKey, fieldBKey), 1e-7));
assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey), stats.getCovariance(fieldBKey, fieldAKey), 1e-7));
// correlation
assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7));
assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7));
}
}
private static boolean nearlyEqual(double a, double b, double epsilon) {
final double absA = Math.abs(a);
final double absB = Math.abs(b);
final double diff = Math.abs(a - b);
if (a == b) { // shortcut, handles infinities
return true;
} else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
// a or b is zero or both are extremely close to it
// relative error is less meaningful here
return diff < (epsilon * Double.MIN_NORMAL);
} else { // use relative error
return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
}
}
}

View File

@ -0,0 +1,103 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.matrix.stats;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.InternalAggregationTestCase;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
 * Serialization and reduce tests for {@link InternalMatrixStats}.
 */
public class InternalMatrixStatsTests extends InternalAggregationTestCase<InternalMatrixStats> {

    /**
     * Creates a random {@link InternalMatrixStats} instance.
     * The {@link MatrixStatsResults} are randomly left {@code null} to also cover the
     * per-shard (pre-reduce) state of the aggregation.
     */
    @Override
    protected InternalMatrixStats createTestInstance(String name, List<PipelineAggregator> pipelineAggregators,
                                                     Map<String, Object> metaData) {
        int numFields = randomInt(128);
        String[] fieldNames = new String[numFields];
        double[] fieldValues = new double[numFields];
        for (int i = 0; i < numFields; i++) {
            fieldNames[i] = Integer.toString(i);
            fieldValues[i] = randomDouble();
        }
        RunningStats runningStats = new RunningStats();
        runningStats.add(fieldNames, fieldValues);
        MatrixStatsResults matrixStatsResults = randomBoolean() ? new MatrixStatsResults(runningStats) : null;
        // Use the provided name, pipeline aggregators and metadata instead of hard-coded
        // values so the base test case can vary them between instances.
        return new InternalMatrixStats(name, 1L, runningStats, matrixStatsResults, pipelineAggregators, metaData);
    }

    @Override
    protected Writeable.Reader<InternalMatrixStats> instanceReader() {
        return InternalMatrixStats::new;
    }

    /**
     * Streams a fixed number of (a, b) value pairs into a random number of shard-level
     * aggregations, reduces them, and checks the reduced matrix stats against a
     * multi-pass reference computation over the same values.
     */
    @Override
    public void testReduceRandom() {
        int numValues = 10000;
        int numShards = randomIntBetween(1, 20);
        // Integer division already truncates towards zero; Math.floor is redundant here.
        int valuesPerShard = numValues / numShards;

        List<Double> aValues = new ArrayList<>();
        List<Double> bValues = new ArrayList<>();

        RunningStats runningStats = new RunningStats();
        List<InternalAggregation> shardResults = new ArrayList<>();

        int valuePerShardCounter = 0;
        for (int i = 0; i < numValues; i++) {
            double valueA = randomDouble();
            aValues.add(valueA);
            double valueB = randomDouble();
            bValues.add(valueB);

            runningStats.add(new String[]{"a", "b"}, new double[]{valueA, valueB});
            if (++valuePerShardCounter == valuesPerShard) {
                shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap()));
                runningStats = new RunningStats();
                valuePerShardCounter = 0;
            }
        }
        // Flush any remainder that did not fill a whole shard.
        if (valuePerShardCounter != 0) {
            shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap()));
        }

        MultiPassStats multiPassStats = new MultiPassStats("a", "b");
        multiPassStats.computeStats(aValues, bValues);

        ScriptService mockScriptService = mockScriptService();
        MockBigArrays bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService());
        InternalAggregation.ReduceContext context =
            new InternalAggregation.ReduceContext(bigArrays, mockScriptService, true);
        InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context);
        multiPassStats.assertNearlyEqual(reduced.getResults());
    }

    @Override
    protected void assertReduced(InternalMatrixStats reduced, List<InternalMatrixStats> inputs) {
        // Unused: testReduceRandom above is overridden and performs its own verification.
        throw new UnsupportedOperationException();
    }
}

View File

@ -0,0 +1,155 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.matrix.stats;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Reference implementation that computes descriptive statistics for two fields in
 * multiple passes over the full data set, used to verify the single-pass streaming
 * results produced by the matrix stats aggregation.
 */
class MultiPassStats {

    // keys identifying the two fields in all per-field maps below
    private final String fieldAKey;
    private final String fieldBKey;

    private long count;
    private Map<String, Double> means = new HashMap<>();
    private Map<String, Double> variances = new HashMap<>();
    private Map<String, Double> skewness = new HashMap<>();
    private Map<String, Double> kurtosis = new HashMap<>();
    private Map<String, HashMap<String, Double>> covariances = new HashMap<>();
    private Map<String, HashMap<String, Double>> correlations = new HashMap<>();

    MultiPassStats(String fieldAName, String fieldBName) {
        this.fieldAKey = fieldAName;
        this.fieldBKey = fieldBName;
    }

    /**
     * Computes count, mean, variance, skewness, kurtosis, covariance and correlation
     * for the two value lists. Assumes fieldA and fieldB have the same length
     * (count is taken from fieldA only).
     */
    @SuppressWarnings("unchecked")
    void computeStats(final List<Double> fieldA, final List<Double> fieldB) {
        // set count
        count = fieldA.size();
        double meanA = 0d;
        double meanB = 0d;

        // compute mean
        for (int n = 0; n < count; ++n) {
            // fieldA
            meanA += fieldA.get(n);
            meanB += fieldB.get(n);
        }
        means.put(fieldAKey, meanA / count);
        means.put(fieldBKey, meanB / count);

        // compute variance, skewness, and kurtosis
        double dA;
        double dB;
        double skewA = 0d;
        double skewB = 0d;
        double kurtA = 0d;
        double kurtB = 0d;
        double varA = 0d;
        double varB = 0d;
        double cVar = 0d;
        for (int n = 0; n < count; ++n) {
            dA = fieldA.get(n) - means.get(fieldAKey);
            varA += dA * dA;
            skewA += dA * dA * dA;
            kurtA += dA * dA * dA * dA;
            dB = fieldB.get(n) - means.get(fieldBKey);
            varB += dB * dB;
            skewB += dB * dB * dB;
            kurtB += dB * dB * dB * dB;
            cVar += dA * dB;
        }
        // sample (n-1 denominator) variance; skewness/kurtosis normalized by it
        variances.put(fieldAKey, varA / (count - 1));
        final double stdA = Math.sqrt(variances.get(fieldAKey));
        variances.put(fieldBKey, varB / (count - 1));
        final double stdB = Math.sqrt(variances.get(fieldBKey));
        skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA));
        skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB));
        kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey)));
        kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey)));

        // compute covariance (diagonal entries fixed at 1d, matching the checked results)
        final HashMap<String, Double> fieldACovar = new HashMap<>(2);
        fieldACovar.put(fieldAKey, 1d);
        cVar /= count - 1;
        fieldACovar.put(fieldBKey, cVar);
        covariances.put(fieldAKey, fieldACovar);
        final HashMap<String, Double> fieldBCovar = new HashMap<>(2);
        fieldBCovar.put(fieldAKey, cVar);
        fieldBCovar.put(fieldBKey, 1d);
        covariances.put(fieldBKey, fieldBCovar);

        // compute correlation (symmetric; diagonal fixed at 1d)
        final HashMap<String, Double> fieldACorr = new HashMap<>();
        fieldACorr.put(fieldAKey, 1d);
        double corr = covariances.get(fieldAKey).get(fieldBKey);
        corr /= stdA * stdB;
        fieldACorr.put(fieldBKey, corr);
        correlations.put(fieldAKey, fieldACorr);
        final HashMap<String, Double> fieldBCorr = new HashMap<>();
        fieldBCorr.put(fieldAKey, corr);
        fieldBCorr.put(fieldBKey, 1d);
        correlations.put(fieldBKey, fieldBCorr);
    }

    /**
     * Asserts the streamed stats match this multi-pass reference within per-statistic
     * tolerances (looser for skewness/kurtosis, which accumulate more round-off error).
     */
    void assertNearlyEqual(MatrixStatsResults stats) {
        assertEquals(count, stats.getDocCount());
        assertEquals(count, stats.getFieldCount(fieldAKey));
        assertEquals(count, stats.getFieldCount(fieldBKey));
        // means
        assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7));
        assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7));
        // variances
        assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7));
        assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7));
        // skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
        assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4));
        assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4));
        // kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
        assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4));
        assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4));
        // covariances
        assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey),stats.getCovariance(fieldAKey, fieldBKey), 1e-7));
        assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey),stats.getCovariance(fieldBKey, fieldAKey), 1e-7));
        // correlation
        assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7));
        assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7));
    }

    /**
     * Relative-error comparison of two doubles: exact equality short-circuits
     * (handles infinities), near-zero values compare against an absolute bound,
     * otherwise the relative difference must be below epsilon.
     */
    private static boolean nearlyEqual(double a, double b, double epsilon) {
        final double absA = Math.abs(a);
        final double absB = Math.abs(b);
        final double diff = Math.abs(a - b);
        if (a == b) { // shortcut, handles infinities
            return true;
        } else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
            // a or b is zero or both are extremely close to it
            // relative error is less meaningful here
            return diff < (epsilon * Double.MIN_NORMAL);
        } else { // use relative error
            return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
        }
    }
}

View File

@ -17,7 +17,7 @@
* under the License.
*/
package org.elasticsearch.search.aggregations;
package org.elasticsearch.test;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
@ -26,8 +26,8 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.ArrayList;
import java.util.Collections;