variance aggregator support for double columns (#9076)

* variance aggregator support for double column instead of casting to float

* docs

* everything in its right place

* checkstyle

* adjustments
This commit is contained in:
Clint Wylie 2020-02-12 09:32:42 -08:00 committed by GitHub
parent d268ff7297
commit c3ebb5eb65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 134 additions and 35 deletions

View File

@ -50,7 +50,7 @@ To use this feature, an "variance" aggregator must be included at indexing time.
The ingestion aggregator can only apply to numeric values. If you use "variance"
then any input rows missing the value will be considered to have a value of 0.
User can specify expected input type as one of "float", "long", "variance" for ingestion, which is by default "float".
User can specify expected input type as one of "float", "double", "long", "variance" for ingestion, which is by default "float".
```json
{

View File

@ -21,6 +21,7 @@ package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
@ -31,10 +32,6 @@ public abstract class VarianceAggregator implements Aggregator
{
protected final VarianceAggregatorCollector holder = new VarianceAggregatorCollector();
public VarianceAggregator()
{
}
@Override
public Object get()
{
@ -66,37 +63,56 @@ public abstract class VarianceAggregator implements Aggregator
public static final class FloatVarianceAggregator extends VarianceAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
private final BaseFloatColumnValueSelector selector;
public FloatVarianceAggregator(BaseFloatColumnValueSelector selector)
{
super();
this.selector = selector;
}
@Override
public void aggregate()
{
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
if (noNulls || !selector.isNull()) {
holder.add(selector.getFloat());
}
}
}
public static final class DoubleVarianceAggregator extends VarianceAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
private final BaseDoubleColumnValueSelector selector;
public DoubleVarianceAggregator(BaseDoubleColumnValueSelector selector)
{
this.selector = selector;
}
@Override
public void aggregate()
{
if (noNulls || !selector.isNull()) {
holder.add(selector.getDouble());
}
}
}
public static final class LongVarianceAggregator extends VarianceAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
private final BaseLongColumnValueSelector selector;
public LongVarianceAggregator(BaseLongColumnValueSelector selector)
{
super();
this.selector = selector;
}
@Override
public void aggregate()
{
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
if (noNulls || !selector.isNull()) {
holder.add(selector.getLong());
}
}
@ -108,7 +124,6 @@ public abstract class VarianceAggregator implements Aggregator
public ObjectVarianceAggregator(BaseObjectColumnValueSelector<?> selector)
{
super();
this.selector = selector;
}

View File

@ -134,6 +134,17 @@ public class VarianceAggregatorCollector
return this;
}
public VarianceAggregatorCollector add(double v)
{
count++;
sum += v;
if (count > 1) {
double t = count * v - sum;
nvariance += (t * t) / ((double) count * (count - 1));
}
return this;
}
public VarianceAggregatorCollector add(long v)
{
count++;

View File

@ -38,6 +38,8 @@ import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -51,10 +53,12 @@ import java.util.Objects;
@JsonTypeName("variance")
public class VarianceAggregatorFactory extends AggregatorFactory
{
private static final String VARIANCE_TYPE_NAME = "variance";
protected final String fieldName;
protected final String name;
@Nullable
protected final String estimator;
@Nullable
private final String inputType;
protected final boolean isVariancePop;
@ -74,7 +78,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
this.fieldName = fieldName;
this.estimator = estimator;
this.isVariancePop = VarianceAggregatorCollector.isVariancePop(estimator);
this.inputType = inputType == null ? "float" : inputType;
this.inputType = inputType;
}
public VarianceAggregatorFactory(String name, String fieldName)
@ -85,7 +89,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
@Override
public String getTypeName()
{
return "variance";
return VARIANCE_TYPE_NAME;
}
@Override
@ -102,15 +106,21 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return NoopAggregator.instance();
}
if ("float".equalsIgnoreCase(inputType)) {
final String type = getTypeString(metricFactory);
if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.FloatVarianceAggregator(selector);
} else if ("long".equalsIgnoreCase(inputType)) {
} else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.LongVarianceAggregator(selector);
} else if ("variance".equalsIgnoreCase(inputType)) {
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
return new VarianceAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, long or variance, got a %s", fieldName, inputType
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
inputType
);
}
@ -121,15 +131,21 @@ public class VarianceAggregatorFactory extends AggregatorFactory
if (selector instanceof NilColumnValueSelector) {
return NoopBufferAggregator.instance();
}
if ("float".equalsIgnoreCase(inputType)) {
final String type = getTypeString(metricFactory);
if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.FloatVarianceAggregator(selector);
} else if ("long".equalsIgnoreCase(inputType)) {
} else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.LongVarianceAggregator(selector);
} else if ("variance".equalsIgnoreCase(inputType)) {
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, long or variance, got a %s", fieldName, inputType
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
inputType
);
}
@ -249,7 +265,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
@JsonProperty
public String getInputType()
{
return inputType;
return inputType == null ? StringUtils.toLowerCase(ValueType.FLOAT.name()) : inputType;
}
@Override
@ -304,4 +320,19 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return Objects.hash(fieldName, name, estimator, inputType, isVariancePop);
}
private String getTypeString(ColumnSelectorFactory metricFactory)
{
String type = inputType;
if (type == null) {
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
if (capabilities != null) {
type = StringUtils.toLowerCase(capabilities.getType().name());
} else {
type = StringUtils.toLowerCase(ValueType.FLOAT.name());
}
}
return type;
}
}

View File

@ -23,9 +23,11 @@ import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
import java.nio.ByteBuffer;
/**
@ -79,6 +81,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
public static final class FloatVarianceAggregator extends VarianceBufferAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
private final BaseFloatColumnValueSelector selector;
public FloatVarianceAggregator(BaseFloatColumnValueSelector selector)
@ -89,7 +92,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
if (noNulls || !selector.isNull()) {
float v = selector.getFloat();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
@ -110,8 +113,43 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
}
}
public static final class DoubleVarianceAggregator extends VarianceBufferAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
private final BaseDoubleColumnValueSelector selector;
public DoubleVarianceAggregator(BaseDoubleColumnValueSelector selector)
{
this.selector = selector;
}
@Override
public void aggregate(ByteBuffer buf, int position)
{
if (noNulls || !selector.isNull()) {
double v = selector.getDouble();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
buf.putDouble(position + NVARIANCE_OFFSET, variance);
}
}
}
@Override
public void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
inspector.visit("selector", selector);
}
}
public static final class LongVarianceAggregator extends VarianceBufferAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
private final BaseLongColumnValueSelector selector;
public LongVarianceAggregator(BaseLongColumnValueSelector selector)
@ -122,7 +160,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
if (noNulls || !selector.isNull()) {
long v = selector.getLong();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;

View File

@ -100,14 +100,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
virtualColumns.add(virtualColumn);
}
if (inputType == ValueType.LONG) {
inputTypeName = "long";
} else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) {
inputTypeName = "float";
} else {
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
switch (inputType) {
case LONG:
case DOUBLE:
case FLOAT:
inputTypeName = StringUtils.toLowerCase(inputType.name());
break;
default:
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
}
if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
estimator = "population";
} else {

View File

@ -67,6 +67,7 @@ public class VarianceAggregatorTest extends InitializedNullHandlingTest
selector = new TestFloatColumnSelector(values);
colSelectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
EasyMock.expect(colSelectorFactory.makeColumnValueSelector("nilly")).andReturn(selector);
EasyMock.expect(colSelectorFactory.getColumnCapabilities("nilly")).andReturn(null);
EasyMock.replay(colSelectorFactory);
}

View File

@ -201,7 +201,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
if (raw != null) {
if (raw instanceof Double) {
double v = ((Double) raw).doubleValue() * multiply;
holder.add((float) v);
holder.add(v);
} else if (raw instanceof Float) {
float v = ((Float) raw).floatValue() * multiply;
holder.add(v);
@ -263,7 +263,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
@ -318,7 +318,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
@ -373,7 +373,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
@ -435,7 +435,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
@ -501,7 +501,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"),
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
)