LUCENE-7590: add sum, variance and stdev stats to NumericDVStats

This commit is contained in:
Shai Erera 2016-12-15 12:52:37 +02:00
parent e4f31fab2f
commit 295cab7216
3 changed files with 95 additions and 9 deletions

View File

@ -67,6 +67,9 @@ New features
* LUCENE-7466: Added AxiomaticSimilarity. (Peilin Yang via Tommaso Teofili) * LUCENE-7466: Added AxiomaticSimilarity. (Peilin Yang via Tommaso Teofili)
* LUCENE-7590: Added DocValuesStatsCollector to compute statistics on DocValues
fields. (Shai Erera)
Bug Fixes Bug Fixes
* LUCENE-7547: JapaneseTokenizerFactory was failing to close the * LUCENE-7547: JapaneseTokenizerFactory was failing to close the

View File

@ -98,6 +98,7 @@ public abstract class DocValuesStats<T> {
public static abstract class NumericDocValuesStats<T extends Number> extends DocValuesStats<T> { public static abstract class NumericDocValuesStats<T extends Number> extends DocValuesStats<T> {
protected double mean = 0.0; protected double mean = 0.0;
protected double variance = 0.0;
protected NumericDocValues ndv; protected NumericDocValues ndv;
@ -116,15 +117,32 @@ public abstract class DocValuesStats<T> {
return ndv.advanceExact(doc); return ndv.advanceExact(doc);
} }
/** The mean of all values of the field. Undefined when {@link #count} is zero. */ /** The mean of all values of the field. */
public final double mean() { public final double mean() {
return mean; return mean;
} }
/** Returns the variance of all values of the field. */
public final double variance() {
int count = count();
return count > 0 ? variance / count : 0;
}
/** Returns the stdev of all values of the field. */
public final double stdev() {
return Math.sqrt(variance());
}
/** Returns the sum of values of the field. Note that if the values are large, the {@code sum} might overflow. */
public abstract T sum();
} }
/** Holds DocValues statistics for a numeric field storing {@code long} values. */ /** Holds DocValues statistics for a numeric field storing {@code long} values. */
public static final class LongDocValuesStats extends NumericDocValuesStats<Long> { public static final class LongDocValuesStats extends NumericDocValuesStats<Long> {
// To avoid boxing 'long' to 'Long' while the sum is computed, declare it as private variable.
private long sum = 0;
public LongDocValuesStats(String field) { public LongDocValuesStats(String field) {
super(field, Long.MAX_VALUE, Long.MIN_VALUE); super(field, Long.MAX_VALUE, Long.MIN_VALUE);
} }
@ -138,13 +156,24 @@ public abstract class DocValuesStats<T> {
if (val < min) { if (val < min) {
min = val; min = val;
} }
sum += val;
double oldMean = mean;
mean += (val - mean) / count; mean += (val - mean) / count;
variance += (val - mean) * (val - oldMean);
}
@Override
public Long sum() {
return sum;
} }
} }
/** Holds DocValues statistics for a numeric field storing {@code double} values. */ /** Holds DocValues statistics for a numeric field storing {@code double} values. */
public static final class DoubleDocValuesStats extends NumericDocValuesStats<Double> { public static final class DoubleDocValuesStats extends NumericDocValuesStats<Double> {
// To avoid boxing 'double' to 'Double' while the sum is computed, declare it as private variable.
private double sum = 0;
public DoubleDocValuesStats(String field) { public DoubleDocValuesStats(String field) {
super(field, Double.MAX_VALUE, Double.MIN_VALUE); super(field, Double.MAX_VALUE, Double.MIN_VALUE);
} }
@ -158,7 +187,15 @@ public abstract class DocValuesStats<T> {
if (Double.compare(val, min) < 0) { if (Double.compare(val, min) < 0) {
min = val; min = val;
} }
sum += val;
double oldMean = mean;
mean += (val - mean) / count; mean += (val - mean) / count;
variance += (val - mean) * (val - oldMean);
}
@Override
public Double sum() {
return sum;
} }
} }

View File

@ -18,6 +18,8 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.LongSummaryStatistics;
import java.util.stream.DoubleStream; import java.util.stream.DoubleStream;
import java.util.stream.LongStream; import java.util.stream.LongStream;
@ -57,7 +59,33 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
} }
} }
public void testRandomDocsWithLongValues() throws IOException { public void testOneDoc() throws IOException {
try (Directory dir = newDirectory();
IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) {
String field = "numeric";
Document doc = new Document();
doc.add(new NumericDocValuesField(field, 1));
doc.add(new StringField("id", "doc1", Store.NO));
indexWriter.addDocument(doc);
try (DirectoryReader reader = DirectoryReader.open(indexWriter)) {
IndexSearcher searcher = new IndexSearcher(reader);
LongDocValuesStats stats = new LongDocValuesStats(field);
searcher.search(new MatchAllDocsQuery(), new DocValuesStatsCollector(stats));
assertEquals(1, stats.count());
assertEquals(0, stats.missing());
assertEquals(1, stats.max().longValue());
assertEquals(1, stats.min().longValue());
assertEquals(1, stats.sum().longValue());
assertEquals(1, stats.mean(), 0.0001);
assertEquals(0, stats.variance(), 0.0001);
assertEquals(0, stats.stdev(), 0.0001);
}
}
}
public void testDocsWithLongValues() throws IOException {
try (Directory dir = newDirectory(); try (Directory dir = newDirectory();
IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) { IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) {
String field = "numeric"; String field = "numeric";
@ -94,15 +122,20 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
assertEquals(expCount, stats.count()); assertEquals(expCount, stats.count());
assertEquals(getZeroValues(docValues).count() - reader.numDeletedDocs(), stats.missing()); assertEquals(getZeroValues(docValues).count() - reader.numDeletedDocs(), stats.missing());
if (stats.count() > 0) { if (stats.count() > 0) {
assertEquals(getPositiveValues(docValues).max().getAsLong(), stats.max().longValue()); LongSummaryStatistics sumStats = getPositiveValues(docValues).summaryStatistics();
assertEquals(getPositiveValues(docValues).min().getAsLong(), stats.min().longValue()); assertEquals(sumStats.getMax(), stats.max().longValue());
assertEquals(getPositiveValues(docValues).average().getAsDouble(), stats.mean(), 0.00001); assertEquals(sumStats.getMin(), stats.min().longValue());
assertEquals(sumStats.getAverage(), stats.mean(), 0.00001);
assertEquals(sumStats.getSum(), stats.sum().longValue());
double variance = computeVariance(docValues, stats.mean, stats.count());
assertEquals(variance, stats.variance(), 0.00001);
assertEquals(Math.sqrt(variance), stats.stdev(), 0.00001);
} }
} }
} }
} }
public void testRandomDocsWithDoubleValues() throws IOException { public void testDocsWithDoubleValues() throws IOException {
try (Directory dir = newDirectory(); try (Directory dir = newDirectory();
IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) { IndexWriter indexWriter = new IndexWriter(dir, newIndexWriterConfig())) {
String field = "numeric"; String field = "numeric";
@ -139,9 +172,14 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
assertEquals(expCount, stats.count()); assertEquals(expCount, stats.count());
assertEquals(getZeroValues(docValues).count() - reader.numDeletedDocs(), stats.missing()); assertEquals(getZeroValues(docValues).count() - reader.numDeletedDocs(), stats.missing());
if (stats.count() > 0) { if (stats.count() > 0) {
assertEquals(getPositiveValues(docValues).max().getAsDouble(), stats.max().doubleValue(), 0.00001); DoubleSummaryStatistics sumStats = getPositiveValues(docValues).summaryStatistics();
assertEquals(getPositiveValues(docValues).min().getAsDouble(), stats.min().doubleValue(), 0.00001); assertEquals(sumStats.getMax(), stats.max().doubleValue(), 0.00001);
assertEquals(getPositiveValues(docValues).average().getAsDouble(), stats.mean(), 0.00001); assertEquals(sumStats.getMin(), stats.min().doubleValue(), 0.00001);
assertEquals(sumStats.getAverage(), stats.mean(), 0.00001);
assertEquals(sumStats.getSum(), stats.sum(), 0.00001);
double variance = computeVariance(docValues, stats.mean, stats.count());
assertEquals(variance, stats.variance(), 0.00001);
assertEquals(Math.sqrt(variance), stats.stdev(), 0.00001);
} }
} }
} }
@ -163,4 +201,12 @@ public class TestDocValuesStatsCollector extends LuceneTestCase {
return Arrays.stream(docValues).filter(v -> v == 0); return Arrays.stream(docValues).filter(v -> v == 0);
} }
private static double computeVariance(long[] values, double mean, int count) {
return getPositiveValues(values).mapToDouble(v -> (v - mean) * (v-mean)).sum() / count;
}
private static double computeVariance(double[] values, double mean, int count) {
return getPositiveValues(values).map(v -> (v - mean) * (v-mean)).sum() / count;
}
} }