mirror of https://github.com/apache/lucene.git
LUCENE-7590: add sum, variance and stdev stats to NumericDVStats
This commit is contained in:
parent
e4f31fab2f
commit
295cab7216
|
@ -67,6 +67,9 @@ New features
|
|||
|
||||
* LUCENE-7466: Added AxiomaticSimilarity. (Peilin Yang via Tommaso Teofili)
|
||||
|
||||
* LUCENE-7590: Added DocValuesStatsCollector to compute statistics on DocValues
|
||||
fields. (Shai Erera)
|
||||
|
||||
Bug Fixes
|
||||
|
||||
* LUCENE-7547: JapaneseTokenizerFactory was failing to close the
|
||||
|
|
|
@ -98,6 +98,7 @@ public abstract class DocValuesStats<T> {
|
|||
public static abstract class NumericDocValuesStats<T extends Number> extends DocValuesStats<T> {
|
||||
|
||||
protected double mean = 0.0;
|
||||
protected double variance = 0.0;
|
||||
|
||||
protected NumericDocValues ndv;
|
||||
|
||||
|
@ -116,15 +117,32 @@ public abstract class DocValuesStats<T> {
|
|||
return ndv.advanceExact(doc);
|
||||
}
|
||||
|
||||
/** The mean of all values of the field. Undefined when {@link #count} is zero. */
|
||||
/** The mean of all values of the field. */
|
||||
public final double mean() {
|
||||
return mean;
|
||||
}
|
||||
|
||||
/** Returns the variance of all values of the field. */
|
||||
public final double variance() {
|
||||
int count = count();
|
||||
return count > 0 ? variance / count : 0;
|
||||
}
|
||||
|
||||
/** Returns the stdev of all values of the field. */
|
||||
public final double stdev() {
|
||||
return Math.sqrt(variance());
|
||||
}
|
||||
|
||||
/**
 * Returns the sum of all values of the field.
 *
 * <p>Implementations add values up in the field's own numeric type without any overflow check, so
 * the result may overflow (or lose precision) if the values are large.
 *
 * @return the sum of the accumulated values
 */
|
||||
public abstract T sum();
|
||||
}
|
||||
|
||||
/** Holds DocValues statistics for a numeric field storing {@code long} values. */
|
||||
public static final class LongDocValuesStats extends NumericDocValuesStats<Long> {
|
||||
|
||||
// To avoid boxing 'long' to 'Long' while the sum is computed, declare it as private variable.
|
||||
private long sum = 0;
|
||||
|
||||
/**
 * Creates statistics for the given {@code long} DocValues field.
 *
 * <p>Passes {@link Long#MAX_VALUE} and {@link Long#MIN_VALUE} to the superclass; presumably these
 * seed the running minimum and maximum so that the first accumulated value replaces both —
 * confirm against the superclass constructor.
 *
 * @param field name of the field to compute statistics for
 */
public LongDocValuesStats(String field) {
|
||||
super(field, Long.MAX_VALUE, Long.MIN_VALUE);
|
||||
}
|
||||
|
@ -138,13 +156,24 @@ public abstract class DocValuesStats<T> {
|
|||
if (val < min) {
|
||||
min = val;
|
||||
}
|
||||
sum += val;
|
||||
double oldMean = mean;
|
||||
mean += (val - mean) / count;
|
||||
variance += (val - mean) * (val - oldMean);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long sum() {
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
||||
/** Holds DocValues statistics for a numeric field storing {@code double} values. */
|
||||
public static final class DoubleDocValuesStats extends NumericDocValuesStats<Double> {
|
||||
|
||||
// To avoid boxing 'double' to 'Double' while the sum is computed, declare it as private variable.
|
||||
private double sum = 0;
|
||||
|
||||
/**
 * Creates statistics for the given {@code double} DocValues field.
 *
 * <p>The running minimum is seeded with {@link Double#MAX_VALUE} and the running maximum with
 * {@code -Double.MAX_VALUE}, the most negative finite double. {@link Double#MIN_VALUE} must not
 * be used as the maximum seed: it is the smallest <em>positive</em> double, so a field holding
 * only negative (or zero) values would never replace it and would report a wrong maximum.
 *
 * <p>NOTE(review): the argument order (field, initial min, initial max) is taken from the
 * {@code long} sibling, which passes {@code Long.MAX_VALUE, Long.MIN_VALUE} — confirm against
 * the superclass constructor.
 *
 * @param field name of the field to compute statistics for
 */
public DoubleDocValuesStats(String field) {
  super(field, Double.MAX_VALUE, -Double.MAX_VALUE);
}
|
||||
|
@ -158,7 +187,15 @@ public abstract class DocValuesStats<T> {
|
|||
if (Double.compare(val, min) < 0) {
|
||||
min = val;
|
||||
}
|
||||
sum += val;
|
||||
double oldMean = mean;
|
||||
mean += (val - mean) / count;
|
||||
variance += (val - mean) * (val - oldMean);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double sum() {
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ package org.apache.lucene.search;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.DoubleSummaryStatistics;
|
||||
import java.util.LongSummaryStatistics;
|
||||
import java.util.stream.DoubleStream;
|
||||
import java.util.stream.LongStream;
|
||||
|
||||
|
@ -57,7 +59,33 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testRandomDocsWithLongValues() throws IOException {
|
||||
public void testOneDoc() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
String field = "numeric";
|
||||
Document doc = new Document();
|
||||
doc.add(new NumericDocValuesField(field, 1));
|
||||
doc.add(new StringField("id", "doc1", Store.NO));
|
||||
indexWriter.addDocument(doc);
|
||||
|
||||
try (DirectoryReader reader = DirectoryReader.open(indexWriter)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
LongDocValuesStats stats = new LongDocValuesStats(field);
|
||||
searcher.search(new MatchAllDocsQuery(), new DocValuesStatsCollector(stats));
|
||||
|
||||
assertEquals(1, stats.count());
|
||||
assertEquals(0, stats.missing());
|
||||
assertEquals(1, stats.max().longValue());
|
||||
assertEquals(1, stats.min().longValue());
|
||||
assertEquals(1, stats.sum().longValue());
|
||||
assertEquals(1, stats.mean(), 0.0001);
|
||||
assertEquals(0, stats.variance(), 0.0001);
|
||||
assertEquals(0, stats.stdev(), 0.0001);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testDocsWithLongValues() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
String field = "numeric";
|
||||
|
@ -94,15 +122,20 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
|
|||
assertEquals(expCount, stats.count());
|
||||
assertEquals(getZeroValues(docValues).count() - reader.numDeletedDocs(), stats.missing());
|
||||
if (stats.count() > 0) {
|
||||
assertEquals(getPositiveValues(docValues).max().getAsLong(), stats.max().longValue());
|
||||
assertEquals(getPositiveValues(docValues).min().getAsLong(), stats.min().longValue());
|
||||
assertEquals(getPositiveValues(docValues).average().getAsDouble(), stats.mean(), 0.00001);
|
||||
LongSummaryStatistics sumStats = getPositiveValues(docValues).summaryStatistics();
|
||||
assertEquals(sumStats.getMax(), stats.max().longValue());
|
||||
assertEquals(sumStats.getMin(), stats.min().longValue());
|
||||
assertEquals(sumStats.getAverage(), stats.mean(), 0.00001);
|
||||
assertEquals(sumStats.getSum(), stats.sum().longValue());
|
||||
double variance = computeVariance(docValues, stats.mean, stats.count());
|
||||
assertEquals(variance, stats.variance(), 0.00001);
|
||||
assertEquals(Math.sqrt(variance), stats.stdev(), 0.00001);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testRandomDocsWithDoubleValues() throws IOException {
|
||||
public void testDocsWithDoubleValues() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
String field = "numeric";
|
||||
|
@ -139,9 +172,14 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
|
|||
assertEquals(expCount, stats.count());
|
||||
assertEquals(getZeroValues(docValues).count() - reader.numDeletedDocs(), stats.missing());
|
||||
if (stats.count() > 0) {
|
||||
assertEquals(getPositiveValues(docValues).max().getAsDouble(), stats.max().doubleValue(), 0.00001);
|
||||
assertEquals(getPositiveValues(docValues).min().getAsDouble(), stats.min().doubleValue(), 0.00001);
|
||||
assertEquals(getPositiveValues(docValues).average().getAsDouble(), stats.mean(), 0.00001);
|
||||
DoubleSummaryStatistics sumStats = getPositiveValues(docValues).summaryStatistics();
|
||||
assertEquals(sumStats.getMax(), stats.max().doubleValue(), 0.00001);
|
||||
assertEquals(sumStats.getMin(), stats.min().doubleValue(), 0.00001);
|
||||
assertEquals(sumStats.getAverage(), stats.mean(), 0.00001);
|
||||
assertEquals(sumStats.getSum(), stats.sum(), 0.00001);
|
||||
double variance = computeVariance(docValues, stats.mean, stats.count());
|
||||
assertEquals(variance, stats.variance(), 0.00001);
|
||||
assertEquals(Math.sqrt(variance), stats.stdev(), 0.00001);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -163,4 +201,12 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
|
|||
return Arrays.stream(docValues).filter(v -> v == 0);
|
||||
}
|
||||
|
||||
private static double computeVariance(long[] values, double mean, int count) {
|
||||
return getPositiveValues(values).mapToDouble(v -> (v - mean) * (v-mean)).sum() / count;
|
||||
}
|
||||
|
||||
private static double computeVariance(double[] values, double mean, int count) {
|
||||
return getPositiveValues(values).map(v -> (v - mean) * (v-mean)).sum() / count;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue