mirror of https://github.com/apache/druid.git
variance aggregator support for double columns (#9076)
* variance aggregator support for double column instead of casting to float * docs * everything in its right place * checkstyle * adjustments
This commit is contained in:
parent
d268ff7297
commit
c3ebb5eb65
|
@ -50,7 +50,7 @@ To use this feature, an "variance" aggregator must be included at indexing time.
|
|||
The ingestion aggregator can only apply to numeric values. If you use "variance"
|
||||
then any input rows missing the value will be considered to have a value of 0.
|
||||
|
||||
User can specify expected input type as one of "float", "long", "variance" for ingestion, which is by default "float".
|
||||
Users can specify the expected input type as one of "float", "double", "long", or "variance" for ingestion; the default is "float".
|
||||
|
||||
```json
|
||||
{
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.apache.druid.query.aggregation.variance;
|
|||
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.query.aggregation.Aggregator;
|
||||
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseFloatColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseLongColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseObjectColumnValueSelector;
|
||||
|
@ -31,10 +32,6 @@ public abstract class VarianceAggregator implements Aggregator
|
|||
{
|
||||
protected final VarianceAggregatorCollector holder = new VarianceAggregatorCollector();
|
||||
|
||||
public VarianceAggregator()
|
||||
{
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object get()
|
||||
{
|
||||
|
@ -66,37 +63,56 @@ public abstract class VarianceAggregator implements Aggregator
|
|||
|
||||
public static final class FloatVarianceAggregator extends VarianceAggregator
|
||||
{
|
||||
private final boolean noNulls = NullHandling.replaceWithDefault();
|
||||
private final BaseFloatColumnValueSelector selector;
|
||||
|
||||
public FloatVarianceAggregator(BaseFloatColumnValueSelector selector)
|
||||
{
|
||||
super();
|
||||
this.selector = selector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate()
|
||||
{
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
if (noNulls || !selector.isNull()) {
|
||||
holder.add(selector.getFloat());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static final class DoubleVarianceAggregator extends VarianceAggregator
|
||||
{
|
||||
private final boolean noNulls = NullHandling.replaceWithDefault();
|
||||
private final BaseDoubleColumnValueSelector selector;
|
||||
|
||||
public DoubleVarianceAggregator(BaseDoubleColumnValueSelector selector)
|
||||
{
|
||||
this.selector = selector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate()
|
||||
{
|
||||
if (noNulls || !selector.isNull()) {
|
||||
holder.add(selector.getDouble());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static final class LongVarianceAggregator extends VarianceAggregator
|
||||
{
|
||||
private final boolean noNulls = NullHandling.replaceWithDefault();
|
||||
private final BaseLongColumnValueSelector selector;
|
||||
|
||||
public LongVarianceAggregator(BaseLongColumnValueSelector selector)
|
||||
{
|
||||
super();
|
||||
this.selector = selector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate()
|
||||
{
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
if (noNulls || !selector.isNull()) {
|
||||
holder.add(selector.getLong());
|
||||
}
|
||||
}
|
||||
|
@ -108,7 +124,6 @@ public abstract class VarianceAggregator implements Aggregator
|
|||
|
||||
public ObjectVarianceAggregator(BaseObjectColumnValueSelector<?> selector)
|
||||
{
|
||||
super();
|
||||
this.selector = selector;
|
||||
}
|
||||
|
||||
|
|
|
@ -134,6 +134,17 @@ public class VarianceAggregatorCollector
|
|||
return this;
|
||||
}
|
||||
|
||||
public VarianceAggregatorCollector add(double v)
|
||||
{
|
||||
count++;
|
||||
sum += v;
|
||||
if (count > 1) {
|
||||
double t = count * v - sum;
|
||||
nvariance += (t * t) / ((double) count * (count - 1));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public VarianceAggregatorCollector add(long v)
|
||||
{
|
||||
count++;
|
||||
|
|
|
@ -38,6 +38,8 @@ import org.apache.druid.query.cache.CacheKeyBuilder;
|
|||
import org.apache.druid.segment.ColumnSelectorFactory;
|
||||
import org.apache.druid.segment.ColumnValueSelector;
|
||||
import org.apache.druid.segment.NilColumnValueSelector;
|
||||
import org.apache.druid.segment.column.ColumnCapabilities;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -51,10 +53,12 @@ import java.util.Objects;
|
|||
@JsonTypeName("variance")
|
||||
public class VarianceAggregatorFactory extends AggregatorFactory
|
||||
{
|
||||
private static final String VARIANCE_TYPE_NAME = "variance";
|
||||
protected final String fieldName;
|
||||
protected final String name;
|
||||
@Nullable
|
||||
protected final String estimator;
|
||||
@Nullable
|
||||
private final String inputType;
|
||||
|
||||
protected final boolean isVariancePop;
|
||||
|
@ -74,7 +78,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
this.fieldName = fieldName;
|
||||
this.estimator = estimator;
|
||||
this.isVariancePop = VarianceAggregatorCollector.isVariancePop(estimator);
|
||||
this.inputType = inputType == null ? "float" : inputType;
|
||||
this.inputType = inputType;
|
||||
}
|
||||
|
||||
public VarianceAggregatorFactory(String name, String fieldName)
|
||||
|
@ -85,7 +89,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
@Override
|
||||
public String getTypeName()
|
||||
{
|
||||
return "variance";
|
||||
return VARIANCE_TYPE_NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -102,15 +106,21 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
return NoopAggregator.instance();
|
||||
}
|
||||
|
||||
if ("float".equalsIgnoreCase(inputType)) {
|
||||
final String type = getTypeString(metricFactory);
|
||||
|
||||
if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
|
||||
return new VarianceAggregator.FloatVarianceAggregator(selector);
|
||||
} else if ("long".equalsIgnoreCase(inputType)) {
|
||||
} else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
|
||||
return new VarianceAggregator.DoubleVarianceAggregator(selector);
|
||||
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
|
||||
return new VarianceAggregator.LongVarianceAggregator(selector);
|
||||
} else if ("variance".equalsIgnoreCase(inputType)) {
|
||||
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
|
||||
return new VarianceAggregator.ObjectVarianceAggregator(selector);
|
||||
}
|
||||
throw new IAE(
|
||||
"Incompatible type for metric[%s], expected a float, long or variance, got a %s", fieldName, inputType
|
||||
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
|
||||
fieldName,
|
||||
inputType
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -121,15 +131,21 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
if (selector instanceof NilColumnValueSelector) {
|
||||
return NoopBufferAggregator.instance();
|
||||
}
|
||||
if ("float".equalsIgnoreCase(inputType)) {
|
||||
final String type = getTypeString(metricFactory);
|
||||
|
||||
if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
|
||||
return new VarianceBufferAggregator.FloatVarianceAggregator(selector);
|
||||
} else if ("long".equalsIgnoreCase(inputType)) {
|
||||
} else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
|
||||
return new VarianceBufferAggregator.DoubleVarianceAggregator(selector);
|
||||
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
|
||||
return new VarianceBufferAggregator.LongVarianceAggregator(selector);
|
||||
} else if ("variance".equalsIgnoreCase(inputType)) {
|
||||
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
|
||||
return new VarianceBufferAggregator.ObjectVarianceAggregator(selector);
|
||||
}
|
||||
throw new IAE(
|
||||
"Incompatible type for metric[%s], expected a float, long or variance, got a %s", fieldName, inputType
|
||||
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
|
||||
fieldName,
|
||||
inputType
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -249,7 +265,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
@JsonProperty
|
||||
public String getInputType()
|
||||
{
|
||||
return inputType;
|
||||
return inputType == null ? StringUtils.toLowerCase(ValueType.FLOAT.name()) : inputType;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -304,4 +320,19 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
|
||||
return Objects.hash(fieldName, name, estimator, inputType, isVariancePop);
|
||||
}
|
||||
|
||||
private String getTypeString(ColumnSelectorFactory metricFactory)
|
||||
{
|
||||
String type = inputType;
|
||||
if (type == null) {
|
||||
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
|
||||
if (capabilities != null) {
|
||||
type = StringUtils.toLowerCase(capabilities.getType().name());
|
||||
} else {
|
||||
type = StringUtils.toLowerCase(ValueType.FLOAT.name());
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -23,9 +23,11 @@ import com.google.common.base.Preconditions;
|
|||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
|
||||
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseFloatColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseLongColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseObjectColumnValueSelector;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
|
@ -79,6 +81,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
|
|||
|
||||
public static final class FloatVarianceAggregator extends VarianceBufferAggregator
|
||||
{
|
||||
private final boolean noNulls = NullHandling.replaceWithDefault();
|
||||
private final BaseFloatColumnValueSelector selector;
|
||||
|
||||
public FloatVarianceAggregator(BaseFloatColumnValueSelector selector)
|
||||
|
@ -89,7 +92,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
|
|||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position)
|
||||
{
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
if (noNulls || !selector.isNull()) {
|
||||
float v = selector.getFloat();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
|
@ -110,8 +113,43 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
|
|||
}
|
||||
}
|
||||
|
||||
public static final class DoubleVarianceAggregator extends VarianceBufferAggregator
|
||||
{
|
||||
private final boolean noNulls = NullHandling.replaceWithDefault();
|
||||
private final BaseDoubleColumnValueSelector selector;
|
||||
|
||||
public DoubleVarianceAggregator(BaseDoubleColumnValueSelector selector)
|
||||
{
|
||||
this.selector = selector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position)
|
||||
{
|
||||
if (noNulls || !selector.isNull()) {
|
||||
double v = selector.getDouble();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
buf.putLong(position, count);
|
||||
buf.putDouble(position + SUM_OFFSET, sum);
|
||||
if (count > 1) {
|
||||
double t = count * v - sum;
|
||||
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
|
||||
buf.putDouble(position + NVARIANCE_OFFSET, variance);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void inspectRuntimeShape(RuntimeShapeInspector inspector)
|
||||
{
|
||||
inspector.visit("selector", selector);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class LongVarianceAggregator extends VarianceBufferAggregator
|
||||
{
|
||||
private final boolean noNulls = NullHandling.replaceWithDefault();
|
||||
private final BaseLongColumnValueSelector selector;
|
||||
|
||||
public LongVarianceAggregator(BaseLongColumnValueSelector selector)
|
||||
|
@ -122,7 +160,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
|
|||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position)
|
||||
{
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
if (noNulls || !selector.isNull()) {
|
||||
long v = selector.getLong();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
|
|
|
@ -100,14 +100,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
virtualColumns.add(virtualColumn);
|
||||
}
|
||||
|
||||
if (inputType == ValueType.LONG) {
|
||||
inputTypeName = "long";
|
||||
} else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) {
|
||||
inputTypeName = "float";
|
||||
} else {
|
||||
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
|
||||
switch (inputType) {
|
||||
case LONG:
|
||||
case DOUBLE:
|
||||
case FLOAT:
|
||||
inputTypeName = StringUtils.toLowerCase(inputType.name());
|
||||
break;
|
||||
default:
|
||||
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
|
||||
}
|
||||
|
||||
|
||||
if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
|
||||
estimator = "population";
|
||||
} else {
|
||||
|
|
|
@ -67,6 +67,7 @@ public class VarianceAggregatorTest extends InitializedNullHandlingTest
|
|||
selector = new TestFloatColumnSelector(values);
|
||||
colSelectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
|
||||
EasyMock.expect(colSelectorFactory.makeColumnValueSelector("nilly")).andReturn(selector);
|
||||
EasyMock.expect(colSelectorFactory.getColumnCapabilities("nilly")).andReturn(null);
|
||||
EasyMock.replay(colSelectorFactory);
|
||||
}
|
||||
|
||||
|
|
|
@ -201,7 +201,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
if (raw != null) {
|
||||
if (raw instanceof Double) {
|
||||
double v = ((Double) raw).doubleValue() * multiply;
|
||||
holder.add((float) v);
|
||||
holder.add(v);
|
||||
} else if (raw instanceof Float) {
|
||||
float v = ((Float) raw).floatValue() * multiply;
|
||||
holder.add(v);
|
||||
|
@ -263,7 +263,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
|
||||
)
|
||||
|
@ -318,7 +318,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
|
||||
)
|
||||
|
@ -373,7 +373,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
|
||||
)
|
||||
|
@ -435,7 +435,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
|
||||
)
|
||||
|
@ -501,7 +501,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
|
||||
new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue