Add sum of squares, variance, and standard deviation to the statistical facet

This commit is contained in:
kimchy 2010-06-08 10:39:45 +03:00
parent d5ff6a7cd4
commit 874993557c
4 changed files with 92 additions and 21 deletions

View File

@ -40,16 +40,19 @@ public class InternalStatisticalFacet implements StatisticalFacet, InternalFacet
private double total; private double total;
private double sumOfSquares;
private long count; private long count;
private InternalStatisticalFacet() { private InternalStatisticalFacet() {
} }
public InternalStatisticalFacet(String name, double min, double max, double total, long count) { public InternalStatisticalFacet(String name, double min, double max, double total, double sumOfSquares, long count) {
this.name = name; this.name = name;
this.min = min; this.min = min;
this.max = max; this.max = max;
this.total = total; this.total = total;
this.sumOfSquares = sumOfSquares;
this.count = count; this.count = count;
} }
@ -85,6 +88,14 @@ public class InternalStatisticalFacet implements StatisticalFacet, InternalFacet
return total(); return total();
} }
/**
 * Returns the sum of squares stored in this facet.
 */
@Override public double sumOfSquares() {
    return sumOfSquares;
}
/**
 * Bean-style alias that delegates to {@link #sumOfSquares()}.
 */
@Override public double getSumOfSquares() {
    final double squares = sumOfSquares();
    return squares;
}
@Override public double mean() { @Override public double mean() {
return total / count; return total / count;
} }
@ -109,10 +120,27 @@ public class InternalStatisticalFacet implements StatisticalFacet, InternalFacet
return max(); return max();
} }
/**
 * Population variance of the values, derived from the stored aggregates as
 * {@code (sumOfSquares - total * total / count) / count}.
 *
 * <p>The subtraction can suffer floating-point cancellation when the values
 * are (nearly) identical and come out as a tiny negative number. Such a
 * result is clamped to zero so the variance is never negative and
 * {@link #stdDeviation()} does not become NaN through rounding alone.
 *
 * @return the variance, never negative; NaN when {@code count},
 *         {@code total} and {@code sumOfSquares} are all zero
 */
@Override public double variance() {
    double raw = (sumOfSquares - ((total * total) / count)) / count;
    // Math.max propagates NaN, so the 0.0 / 0 case still yields NaN.
    return Math.max(0.0, raw);
}
/**
 * Bean-style alias that delegates to {@link #variance()}.
 *
 * @return the same value as {@link #variance()}
 */
@Override public double getVariance() {
    return variance();
}
/**
 * Standard deviation of the values: the square root of {@link #variance()}.
 *
 * @return the standard deviation; NaN whenever {@link #variance()} is NaN
 */
@Override public double stdDeviation() {
    return Math.sqrt(variance());
}
/**
 * Bean-style alias that delegates to {@link #stdDeviation()}.
 *
 * @return the same value as {@link #stdDeviation()}
 */
@Override public double getStdDeviation() {
    return stdDeviation();
}
@Override public Facet aggregate(Iterable<Facet> facets) { @Override public Facet aggregate(Iterable<Facet> facets) {
double min = Double.MAX_VALUE; double min = Double.NaN;
double max = Double.MIN_VALUE; double max = Double.NaN;
double total = 0; double total = 0;
double sumOfSquares = 0;
long count = 0; long count = 0;
for (Facet facet : facets) { for (Facet facet : facets) {
@ -120,26 +148,31 @@ public class InternalStatisticalFacet implements StatisticalFacet, InternalFacet
continue; continue;
} }
InternalStatisticalFacet statsFacet = (InternalStatisticalFacet) facet; InternalStatisticalFacet statsFacet = (InternalStatisticalFacet) facet;
if (statsFacet.min() < min) { if (statsFacet.min() < min || Double.isNaN(min)) {
min = statsFacet.min(); min = statsFacet.min();
} }
if (statsFacet.max() > max) { if (statsFacet.max() > max || Double.isNaN(max)) {
max = statsFacet.max(); max = statsFacet.max();
} }
total += statsFacet.total(); total += statsFacet.total();
sumOfSquares += statsFacet.sumOfSquares();
count += statsFacet.count(); count += statsFacet.count();
} }
return new InternalStatisticalFacet(name, min, max, total, count); return new InternalStatisticalFacet(name, min, max, total, sumOfSquares, count);
} }
@Override public void toXContent(XContentBuilder builder, Params params) throws IOException { @Override public void toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(name); builder.startObject(name);
builder.field("_type", "statistical"); builder.field("_type", "statistical");
builder.field("count", count); builder.field("count", count());
builder.field("total", total); builder.field("total", total());
builder.field("min", min); builder.field("min", min());
builder.field("max", max); builder.field("max", max());
builder.field("mean", mean());
builder.field("sum_of_squares", sumOfSquares());
builder.field("variance", variance());
builder.field("std_deviation", stdDeviation());
builder.endObject(); builder.endObject();
} }

View File

@ -48,6 +48,16 @@ public interface StatisticalFacet extends Facet {
*/ */
double getTotal(); double getTotal();
/**
 * The sum of the squares of the values, i.e. each value multiplied by
 * itself and then added up.
 */
double sumOfSquares();
/**
 * The sum of the squares of the values; bean-style counterpart of
 * {@link #sumOfSquares()}.
 */
double getSumOfSquares();
/** /**
* The mean (average) of the values. * The mean (average) of the values.
*/ */
@ -77,4 +87,24 @@ public interface StatisticalFacet extends Facet {
* The maximum value. * The maximum value.
*/ */
double getMax(); double getMax();
/**
 * Variance of the values.
 *
 * NOTE(review): the implementation visible alongside this interface divides
 * by the count (population variance), not by count - 1 — confirm that this
 * is the intended contract.
 */
double variance();
/**
 * Variance of the values; bean-style counterpart of {@link #variance()}.
 */
double getVariance();
/**
 * Standard deviation of the values, i.e. the square root of
 * {@link #variance()}.
 */
double stdDeviation();
/**
 * Standard deviation of the values; bean-style counterpart of
 * {@link #stdDeviation()}.
 */
double getStdDeviation();
} }

View File

@ -79,44 +79,51 @@ public class StatisticalFacetCollector extends FacetCollector {
} }
@Override public Facet facet() { @Override public Facet facet() {
return new InternalStatisticalFacet(name, statsProc.min(), statsProc.max(), statsProc.total(), statsProc.count()); return new InternalStatisticalFacet(name, statsProc.min(), statsProc.max(), statsProc.total(), statsProc.sumOfSquares(), statsProc.count());
} }
public static class StatsProc implements NumericFieldData.DoubleValueInDocProc { public static class StatsProc implements NumericFieldData.DoubleValueInDocProc {
private double min = Double.MAX_VALUE; private double min = Double.NaN;
private double max = Double.MIN_VALUE; private double max = Double.NaN;
private double total = 0; private double total = 0;
private double sumOfSquares = 0.0;
private long count; private long count;
@Override public void onValue(int docId, double value) { @Override public void onValue(int docId, double value) {
count++; if (value < min || Double.isNaN(min)) {
total += value;
if (value < min) {
min = value; min = value;
} }
if (value > max) { if (value > max || Double.isNaN(max)) {
max = value; max = value;
} }
sumOfSquares += value * value;
total += value;
count++;
} }
public double min() { public final double min() {
return min; return min;
} }
public double max() { public final double max() {
return max; return max;
} }
public double total() { public final double total() {
return total; return total;
} }
public long count() { public final long count() {
return count; return count;
} }
/**
 * Returns the running sum of {@code value * value} accumulated over every
 * value passed to {@link #onValue(int, double)}.
 */
public final double sumOfSquares() {
    return this.sumOfSquares;
}
} }
} }

View File

@ -124,6 +124,7 @@ public class SimpleFacetsTests extends AbstractNodesTests {
assertThat(facet.min(), equalTo(1d)); assertThat(facet.min(), equalTo(1d));
assertThat(facet.max(), equalTo(2d)); assertThat(facet.max(), equalTo(2d));
assertThat(facet.mean(), equalTo(1.5d)); assertThat(facet.mean(), equalTo(1.5d));
assertThat(facet.sumOfSquares(), equalTo(5d));
facet = searchResponse.facets().facet(StatisticalFacet.class, "stats2"); facet = searchResponse.facets().facet(StatisticalFacet.class, "stats2");
assertThat(facet.name(), equalTo(facet.name())); assertThat(facet.name(), equalTo(facet.name()));