Added unit tests for InternalMatrixStats.
Also moved InternalAggregationTestCase to the test-framework module so that it can be used from modules other than core. Relates to #22278
This commit is contained in:
parent
b24326271e
commit
51c74ce547
|
@ -20,11 +20,11 @@
|
|||
package org.elasticsearch.search.aggregations.bucket;
|
||||
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.aggregations.metrics.max.InternalMax;
|
||||
import org.elasticsearch.search.aggregations.metrics.min.InternalMin;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
|
|
@ -21,9 +21,9 @@ package org.elasticsearch.search.aggregations.bucket.geogrid;
|
|||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.elasticsearch.common.geo.GeoHashUtils;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
|
|
@ -19,13 +19,11 @@
|
|||
|
||||
package org.elasticsearch.search.aggregations.bucket.histogram;
|
||||
|
||||
import org.apache.lucene.util.TestUtil;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.joda.time.DateTime;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -22,9 +22,9 @@ package org.elasticsearch.search.aggregations.bucket.histogram;
|
|||
import org.apache.lucene.util.TestUtil;
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
|
|
@ -22,9 +22,9 @@ import org.apache.lucene.util.BytesRef;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregations;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
|
||||
package org.elasticsearch.search.aggregations.bucket.significant;
|
||||
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
|
||||
package org.elasticsearch.search.aggregations.bucket.terms;
|
||||
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
|
|
@ -21,9 +21,9 @@ package org.elasticsearch.search.aggregations.metrics;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.metrics.stats.extended.InternalExtendedStats;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.Collections;
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
package org.elasticsearch.search.aggregations.metrics;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.metrics.max.InternalMax;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -20,9 +20,9 @@ package org.elasticsearch.search.aggregations.metrics;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.metrics.stats.InternalStats;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
|
|
@ -20,9 +20,8 @@
|
|||
package org.elasticsearch.search.aggregations.metrics.avg;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -24,8 +24,8 @@ import org.elasticsearch.common.lease.Releasables;
|
|||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.MockBigArrays;
|
||||
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
package org.elasticsearch.search.aggregations.metrics.geobounds;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
|
|
@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.metrics.geocentroid;
|
|||
import org.apache.lucene.geo.GeoEncodingUtils;
|
||||
import org.elasticsearch.common.geo.GeoPoint;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.elasticsearch.test.geo.RandomGeoGenerator;
|
||||
|
||||
import java.util.Collections;
|
||||
|
|
|
@ -20,9 +20,8 @@
|
|||
package org.elasticsearch.search.aggregations.metrics.min;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.metrics.percentiles;
|
|||
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.List;
|
||||
|
|
|
@ -22,9 +22,8 @@ package org.elasticsearch.search.aggregations.metrics.percentiles.hdr;
|
|||
import org.HdrHistogram.DoubleHistogram;
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.metrics.percentiles.hdr.InternalHDRPercentileRanks;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.metrics.percentiles.tdigest;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -30,8 +30,8 @@ import org.elasticsearch.script.ScriptEngineRegistry;
|
|||
import org.elasticsearch.script.ScriptService;
|
||||
import org.elasticsearch.script.ScriptSettings;
|
||||
import org.elasticsearch.script.ScriptType;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
|
|
@ -20,8 +20,8 @@ package org.elasticsearch.search.aggregations.metrics.sum;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -30,11 +30,11 @@ import org.apache.lucene.util.BytesRef;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.search.SearchHitField;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHitField;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
package org.elasticsearch.search.aggregations.metrics.valuecount;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -21,7 +21,7 @@ package org.elasticsearch.search.aggregations.pipeline;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
|
|
@ -21,9 +21,9 @@ package org.elasticsearch.search.aggregations.pipeline.bucketmetrics.percentile;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.metrics.percentiles.Percentile;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
|
|
|
@ -21,8 +21,8 @@ package org.elasticsearch.search.aggregations.pipeline.derivative;
|
|||
|
||||
import org.elasticsearch.common.io.stream.Writeable.Reader;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
|
|
@ -28,6 +28,7 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static java.util.Collections.emptyMap;
|
||||
|
||||
|
@ -41,7 +42,7 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
|
|||
private final MatrixStatsResults results;
|
||||
|
||||
/** per shard ctor */
|
||||
protected InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results,
|
||||
InternalMatrixStats(String name, long count, RunningStats multiFieldStatsResults, MatrixStatsResults results,
|
||||
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
|
||||
super(name, pipelineAggregators, metaData);
|
||||
assert count >= 0;
|
||||
|
@ -138,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
|
|||
return results.getCorrelation(fieldX, fieldY);
|
||||
}
|
||||
|
||||
MatrixStatsResults getResults() {
|
||||
return results;
|
||||
}
|
||||
|
||||
static class Fields {
|
||||
public static final String FIELDS = "fields";
|
||||
public static final String NAME = "name";
|
||||
|
@ -238,4 +243,16 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
|
|||
|
||||
return new InternalMatrixStats(name, results.getDocCount(), runningStats, results, pipelineAggregators(), getMetaData());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int doHashCode() {
|
||||
return Objects.hash(stats, results);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(Object obj) {
|
||||
InternalMatrixStats other = (InternalMatrixStats) obj;
|
||||
return Objects.equals(this.stats, other.stats) &&
|
||||
Objects.equals(this.results, other.results);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.io.IOException;
|
|||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Descriptive stats gathered per shard. Coordinating node computes final pearson product coefficient
|
||||
|
@ -228,4 +229,18 @@ class MatrixStatsResults implements Writeable {
|
|||
correlation.put(rowName, corRow);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
MatrixStatsResults that = (MatrixStatsResults) o;
|
||||
return Objects.equals(results, that.results) &&
|
||||
Objects.equals(correlation, that.correlation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(results, correlation);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Descriptive stats gathered per shard. Coordinating node computes final correlation and covariance stats
|
||||
|
@ -53,11 +54,11 @@ public class RunningStats implements Writeable, Cloneable {
|
|||
/** covariance values */
|
||||
protected HashMap<String, HashMap<String, Double>> covariances;
|
||||
|
||||
public RunningStats() {
|
||||
RunningStats() {
|
||||
init();
|
||||
}
|
||||
|
||||
public RunningStats(final String[] fieldNames, final double[] fieldVals) {
|
||||
RunningStats(final String[] fieldNames, final double[] fieldVals) {
|
||||
if (fieldVals != null && fieldVals.length > 0) {
|
||||
init();
|
||||
this.add(fieldNames, fieldVals);
|
||||
|
@ -309,4 +310,24 @@ public class RunningStats implements Writeable, Cloneable {
|
|||
throw new ElasticsearchException("Error trying to create a copy of RunningStats");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
RunningStats that = (RunningStats) o;
|
||||
return docCount == that.docCount &&
|
||||
Objects.equals(fieldSum, that.fieldSum) &&
|
||||
Objects.equals(counts, that.counts) &&
|
||||
Objects.equals(means, that.means) &&
|
||||
Objects.equals(variances, that.variances) &&
|
||||
Objects.equals(skewness, that.skewness) &&
|
||||
Objects.equals(kurtosis, that.kurtosis) &&
|
||||
Objects.equals(covariances, that.covariances);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(docCount, fieldSum, counts, means, variances, skewness, kurtosis, covariances);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,15 +22,12 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public abstract class BaseMatrixStatsTestCase extends ESTestCase {
|
||||
protected final int numObs = atLeast(10000);
|
||||
protected final ArrayList<Double> fieldA = new ArrayList<>(numObs);
|
||||
protected final ArrayList<Double> fieldB = new ArrayList<>(numObs);
|
||||
protected final MultiPassStats actualStats = new MultiPassStats();
|
||||
protected final MultiPassStats actualStats = new MultiPassStats(fieldAKey, fieldBKey);
|
||||
protected static final String fieldAKey = "fieldA";
|
||||
protected static final String fieldBKey = "fieldB";
|
||||
|
||||
|
@ -47,123 +44,4 @@ public abstract class BaseMatrixStatsTestCase extends ESTestCase {
|
|||
actualStats.computeStats(fieldA, fieldB);
|
||||
}
|
||||
|
||||
static class MultiPassStats {
|
||||
long count;
|
||||
HashMap<String, Double> means = new HashMap<>();
|
||||
HashMap<String, Double> variances = new HashMap<>();
|
||||
HashMap<String, Double> skewness = new HashMap<>();
|
||||
HashMap<String, Double> kurtosis = new HashMap<>();
|
||||
HashMap<String, HashMap<String, Double>> covariances = new HashMap<>();
|
||||
HashMap<String, HashMap<String, Double>> correlations = new HashMap<>();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
void computeStats(final ArrayList<Double> fieldA, final ArrayList<Double> fieldB) {
|
||||
// set count
|
||||
count = fieldA.size();
|
||||
double meanA = 0d;
|
||||
double meanB = 0d;
|
||||
|
||||
// compute mean
|
||||
for (int n = 0; n < count; ++n) {
|
||||
// fieldA
|
||||
meanA += fieldA.get(n);
|
||||
meanB += fieldB.get(n);
|
||||
}
|
||||
means.put(fieldAKey, meanA/count);
|
||||
means.put(fieldBKey, meanB/count);
|
||||
|
||||
// compute variance, skewness, and kurtosis
|
||||
double dA;
|
||||
double dB;
|
||||
double skewA = 0d;
|
||||
double skewB = 0d;
|
||||
double kurtA = 0d;
|
||||
double kurtB = 0d;
|
||||
double varA = 0d;
|
||||
double varB = 0d;
|
||||
double cVar = 0d;
|
||||
for (int n = 0; n < count; ++n) {
|
||||
dA = fieldA.get(n) - means.get(fieldAKey);
|
||||
varA += dA * dA;
|
||||
skewA += dA * dA * dA;
|
||||
kurtA += dA * dA * dA * dA;
|
||||
dB = fieldB.get(n) - means.get(fieldBKey);
|
||||
varB += dB * dB;
|
||||
skewB += dB * dB * dB;
|
||||
kurtB += dB * dB * dB * dB;
|
||||
cVar += dA * dB;
|
||||
}
|
||||
variances.put(fieldAKey, varA / (count - 1));
|
||||
final double stdA = Math.sqrt(variances.get(fieldAKey));
|
||||
variances.put(fieldBKey, varB / (count - 1));
|
||||
final double stdB = Math.sqrt(variances.get(fieldBKey));
|
||||
skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA));
|
||||
skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB));
|
||||
kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey)));
|
||||
kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey)));
|
||||
|
||||
// compute covariance
|
||||
final HashMap<String, Double> fieldACovar = new HashMap<>(2);
|
||||
fieldACovar.put(fieldAKey, 1d);
|
||||
cVar /= count - 1;
|
||||
fieldACovar.put(fieldBKey, cVar);
|
||||
covariances.put(fieldAKey, fieldACovar);
|
||||
final HashMap<String, Double> fieldBCovar = new HashMap<>(2);
|
||||
fieldBCovar.put(fieldAKey, cVar);
|
||||
fieldBCovar.put(fieldBKey, 1d);
|
||||
covariances.put(fieldBKey, fieldBCovar);
|
||||
|
||||
// compute correlation
|
||||
final HashMap<String, Double> fieldACorr = new HashMap<>();
|
||||
fieldACorr.put(fieldAKey, 1d);
|
||||
double corr = covariances.get(fieldAKey).get(fieldBKey);
|
||||
corr /= stdA * stdB;
|
||||
fieldACorr.put(fieldBKey, corr);
|
||||
correlations.put(fieldAKey, fieldACorr);
|
||||
final HashMap<String, Double> fieldBCorr = new HashMap<>();
|
||||
fieldBCorr.put(fieldAKey, corr);
|
||||
fieldBCorr.put(fieldBKey, 1d);
|
||||
correlations.put(fieldBKey, fieldBCorr);
|
||||
}
|
||||
|
||||
public void assertNearlyEqual(MatrixStatsResults stats) {
|
||||
assertThat(count, equalTo(stats.getDocCount()));
|
||||
assertThat(count, equalTo(stats.getFieldCount(fieldAKey)));
|
||||
assertThat(count, equalTo(stats.getFieldCount(fieldBKey)));
|
||||
// means
|
||||
assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7));
|
||||
assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7));
|
||||
// variances
|
||||
assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7));
|
||||
assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7));
|
||||
// skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
|
||||
assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4));
|
||||
assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4));
|
||||
// kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
|
||||
assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4));
|
||||
assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4));
|
||||
// covariances
|
||||
assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey), stats.getCovariance(fieldAKey, fieldBKey), 1e-7));
|
||||
assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey), stats.getCovariance(fieldBKey, fieldAKey), 1e-7));
|
||||
// correlation
|
||||
assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7));
|
||||
assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7));
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Tolerant comparison of two doubles.
 * <p>
 * Exactly equal values (including equal infinities) always match. If either value is zero, or the
 * absolute difference is below {@link Double#MIN_NORMAL}, the absolute difference is tested against
 * {@code epsilon * Double.MIN_NORMAL}. Otherwise the difference relative to {@code |a| + |b|}
 * (capped at {@link Double#MAX_VALUE}) is tested against {@code epsilon}.
 *
 * @param a       first value
 * @param b       second value
 * @param epsilon tolerance
 * @return {@code true} if the two values are considered equal within the tolerance
 */
private static boolean nearlyEqual(double a, double b, double epsilon) {
|
||||
final double absA = Math.abs(a);
|
||||
final double absB = Math.abs(b);
|
||||
final double diff = Math.abs(a - b);
|
||||
|
||||
if (a == b) { // shortcut, handles infinities
|
||||
return true;
|
||||
} else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
|
||||
// a or b is zero or both are extremely close to it
|
||||
// relative error is less meaningful here
|
||||
return diff < (epsilon * Double.MIN_NORMAL);
|
||||
} else { // use relative error
|
||||
// the denominator is capped so that an overflowing |a| + |b| does not divide by infinity
|
||||
return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.search.aggregations.matrix.stats;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.MockBigArrays;
|
||||
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
||||
import org.elasticsearch.script.ScriptService;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class InternalMatrixStatsTests extends InternalAggregationTestCase<InternalMatrixStats> {
|
||||
|
||||
@Override
|
||||
protected InternalMatrixStats createTestInstance(String name, List<PipelineAggregator> pipelineAggregators,
|
||||
Map<String, Object> metaData) {
|
||||
int numFields = randomInt(128);
|
||||
String[] fieldNames = new String[numFields];
|
||||
double[] fieldValues = new double[numFields];
|
||||
for (int i = 0; i < numFields; i++) {
|
||||
fieldNames[i] = Integer.toString(i);
|
||||
fieldValues[i] = randomDouble();
|
||||
}
|
||||
RunningStats runningStats = new RunningStats();
|
||||
runningStats.add(fieldNames, fieldValues);
|
||||
MatrixStatsResults matrixStatsResults = randomBoolean() ? new MatrixStatsResults(runningStats) : null;
|
||||
return new InternalMatrixStats("_name", 1L, runningStats, matrixStatsResults, Collections.emptyList(), Collections.emptyMap());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<InternalMatrixStats> instanceReader() {
|
||||
return InternalMatrixStats::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void testReduceRandom() {
|
||||
int numValues = 10000;
|
||||
int numShards = randomIntBetween(1, 20);
|
||||
int valuesPerShard = (int) Math.floor(numValues / numShards);
|
||||
|
||||
List<Double> aValues = new ArrayList<>();
|
||||
List<Double> bValues = new ArrayList<>();
|
||||
|
||||
RunningStats runningStats = new RunningStats();
|
||||
List<InternalAggregation> shardResults = new ArrayList<>();
|
||||
|
||||
int valuePerShardCounter = 0;
|
||||
for (int i = 0; i < numValues; i++) {
|
||||
double valueA = randomDouble();
|
||||
aValues.add(valueA);
|
||||
double valueB = randomDouble();
|
||||
bValues.add(valueB);
|
||||
|
||||
runningStats.add(new String[]{"a", "b"}, new double[]{valueA, valueB});
|
||||
if (++valuePerShardCounter == valuesPerShard) {
|
||||
shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap()));
|
||||
runningStats = new RunningStats();
|
||||
valuePerShardCounter = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (valuePerShardCounter != 0) {
|
||||
shardResults.add(new InternalMatrixStats("_name", 1L, runningStats, null, Collections.emptyList(), Collections.emptyMap()));
|
||||
}
|
||||
MultiPassStats multiPassStats = new MultiPassStats("a", "b");
|
||||
multiPassStats.computeStats(aValues, bValues);
|
||||
|
||||
ScriptService mockScriptService = mockScriptService();
|
||||
MockBigArrays bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService());
|
||||
InternalAggregation.ReduceContext context =
|
||||
new InternalAggregation.ReduceContext(bigArrays, mockScriptService, true);
|
||||
InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context);
|
||||
multiPassStats.assertNearlyEqual(reduced.getResults());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void assertReduced(InternalMatrixStats reduced, List<InternalMatrixStats> inputs) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.search.aggregations.matrix.stats;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
class MultiPassStats {
|
||||
|
||||
private final String fieldAKey;
|
||||
private final String fieldBKey;
|
||||
|
||||
private long count;
|
||||
private Map<String, Double> means = new HashMap<>();
|
||||
private Map<String, Double> variances = new HashMap<>();
|
||||
private Map<String, Double> skewness = new HashMap<>();
|
||||
private Map<String, Double> kurtosis = new HashMap<>();
|
||||
private Map<String, HashMap<String, Double>> covariances = new HashMap<>();
|
||||
private Map<String, HashMap<String, Double>> correlations = new HashMap<>();
|
||||
|
||||
MultiPassStats(String fieldAName, String fieldBName) {
|
||||
this.fieldAKey = fieldAName;
|
||||
this.fieldBKey = fieldBName;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
void computeStats(final List<Double> fieldA, final List<Double> fieldB) {
|
||||
// set count
|
||||
count = fieldA.size();
|
||||
double meanA = 0d;
|
||||
double meanB = 0d;
|
||||
|
||||
// compute mean
|
||||
for (int n = 0; n < count; ++n) {
|
||||
// fieldA
|
||||
meanA += fieldA.get(n);
|
||||
meanB += fieldB.get(n);
|
||||
}
|
||||
means.put(fieldAKey, meanA / count);
|
||||
means.put(fieldBKey, meanB / count);
|
||||
|
||||
// compute variance, skewness, and kurtosis
|
||||
double dA;
|
||||
double dB;
|
||||
double skewA = 0d;
|
||||
double skewB = 0d;
|
||||
double kurtA = 0d;
|
||||
double kurtB = 0d;
|
||||
double varA = 0d;
|
||||
double varB = 0d;
|
||||
double cVar = 0d;
|
||||
for (int n = 0; n < count; ++n) {
|
||||
dA = fieldA.get(n) - means.get(fieldAKey);
|
||||
varA += dA * dA;
|
||||
skewA += dA * dA * dA;
|
||||
kurtA += dA * dA * dA * dA;
|
||||
dB = fieldB.get(n) - means.get(fieldBKey);
|
||||
varB += dB * dB;
|
||||
skewB += dB * dB * dB;
|
||||
kurtB += dB * dB * dB * dB;
|
||||
cVar += dA * dB;
|
||||
}
|
||||
variances.put(fieldAKey, varA / (count - 1));
|
||||
final double stdA = Math.sqrt(variances.get(fieldAKey));
|
||||
variances.put(fieldBKey, varB / (count - 1));
|
||||
final double stdB = Math.sqrt(variances.get(fieldBKey));
|
||||
skewness.put(fieldAKey, skewA / ((count - 1) * variances.get(fieldAKey) * stdA));
|
||||
skewness.put(fieldBKey, skewB / ((count - 1) * variances.get(fieldBKey) * stdB));
|
||||
kurtosis.put(fieldAKey, kurtA / ((count - 1) * variances.get(fieldAKey) * variances.get(fieldAKey)));
|
||||
kurtosis.put(fieldBKey, kurtB / ((count - 1) * variances.get(fieldBKey) * variances.get(fieldBKey)));
|
||||
|
||||
// compute covariance
|
||||
final HashMap<String, Double> fieldACovar = new HashMap<>(2);
|
||||
fieldACovar.put(fieldAKey, 1d);
|
||||
cVar /= count - 1;
|
||||
fieldACovar.put(fieldBKey, cVar);
|
||||
covariances.put(fieldAKey, fieldACovar);
|
||||
final HashMap<String, Double> fieldBCovar = new HashMap<>(2);
|
||||
fieldBCovar.put(fieldAKey, cVar);
|
||||
fieldBCovar.put(fieldBKey, 1d);
|
||||
covariances.put(fieldBKey, fieldBCovar);
|
||||
|
||||
// compute correlation
|
||||
final HashMap<String, Double> fieldACorr = new HashMap<>();
|
||||
fieldACorr.put(fieldAKey, 1d);
|
||||
double corr = covariances.get(fieldAKey).get(fieldBKey);
|
||||
corr /= stdA * stdB;
|
||||
fieldACorr.put(fieldBKey, corr);
|
||||
correlations.put(fieldAKey, fieldACorr);
|
||||
final HashMap<String, Double> fieldBCorr = new HashMap<>();
|
||||
fieldBCorr.put(fieldAKey, corr);
|
||||
fieldBCorr.put(fieldBKey, 1d);
|
||||
correlations.put(fieldBKey, fieldBCorr);
|
||||
}
|
||||
|
||||
void assertNearlyEqual(MatrixStatsResults stats) {
|
||||
assertEquals(count, stats.getDocCount());
|
||||
assertEquals(count, stats.getFieldCount(fieldAKey));
|
||||
assertEquals(count, stats.getFieldCount(fieldBKey));
|
||||
// means
|
||||
assertTrue(nearlyEqual(means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7));
|
||||
assertTrue(nearlyEqual(means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7));
|
||||
// variances
|
||||
assertTrue(nearlyEqual(variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7));
|
||||
assertTrue(nearlyEqual(variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7));
|
||||
// skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
|
||||
assertTrue(nearlyEqual(skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4));
|
||||
assertTrue(nearlyEqual(skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4));
|
||||
// kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
|
||||
assertTrue(nearlyEqual(kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4));
|
||||
assertTrue(nearlyEqual(kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4));
|
||||
// covariances
|
||||
assertTrue(nearlyEqual(covariances.get(fieldAKey).get(fieldBKey),stats.getCovariance(fieldAKey, fieldBKey), 1e-7));
|
||||
assertTrue(nearlyEqual(covariances.get(fieldBKey).get(fieldAKey),stats.getCovariance(fieldBKey, fieldAKey), 1e-7));
|
||||
// correlation
|
||||
assertTrue(nearlyEqual(correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7));
|
||||
assertTrue(nearlyEqual(correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7));
|
||||
}
|
||||
|
||||
private static boolean nearlyEqual(double a, double b, double epsilon) {
|
||||
final double absA = Math.abs(a);
|
||||
final double absB = Math.abs(b);
|
||||
final double diff = Math.abs(a - b);
|
||||
|
||||
if (a == b) { // shortcut, handles infinities
|
||||
return true;
|
||||
} else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
|
||||
// a or b is zero or both are extremely close to it
|
||||
// relative error is less meaningful here
|
||||
return diff < (epsilon * Double.MIN_NORMAL);
|
||||
} else { // use relative error
|
||||
return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,7 +17,7 @@
|
|||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.search.aggregations;
|
||||
package org.elasticsearch.test;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
|
@ -26,8 +26,8 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
|
|||
import org.elasticsearch.script.ScriptService;
|
||||
import org.elasticsearch.search.DocValueFormat;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.search.aggregations.InternalAggregation;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
Loading…
Reference in New Issue