diff --git a/docs/development/extensions-core/stats.md b/docs/development/extensions-core/stats.md index fa704ea3018..61cb6981e73 100644 --- a/docs/development/extensions-core/stats.md +++ b/docs/development/extensions-core/stats.md @@ -50,7 +50,7 @@ To use this feature, an "variance" aggregator must be included at indexing time. The ingestion aggregator can only apply to numeric values. If you use "variance" then any input rows missing the value will be considered to have a value of 0. -User can specify expected input type as one of "float", "long", "variance" for ingestion, which is by default "float". +User can specify expected input type as one of "float", "double", "long", "variance" for ingestion, which is by default "float". ```json { diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java index 935a1b83958..5e9c6bf3603 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java @@ -21,6 +21,7 @@ package org.apache.druid.query.aggregation.variance; import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.Aggregator; +import org.apache.druid.segment.BaseDoubleColumnValueSelector; import org.apache.druid.segment.BaseFloatColumnValueSelector; import org.apache.druid.segment.BaseLongColumnValueSelector; import org.apache.druid.segment.BaseObjectColumnValueSelector; @@ -31,10 +32,6 @@ public abstract class VarianceAggregator implements Aggregator { protected final VarianceAggregatorCollector holder = new VarianceAggregatorCollector(); - public VarianceAggregator() - { - } - @Override public Object get() { @@ -66,37 +63,56 @@ public abstract class VarianceAggregator implements Aggregator public static final class FloatVarianceAggregator extends VarianceAggregator { + private final boolean noNulls = NullHandling.replaceWithDefault(); private final BaseFloatColumnValueSelector selector; public FloatVarianceAggregator(BaseFloatColumnValueSelector selector) { - super(); this.selector = selector; } @Override public void aggregate() { - if (NullHandling.replaceWithDefault() || !selector.isNull()) { + if (noNulls || !selector.isNull()) { holder.add(selector.getFloat()); } } } + public static final class DoubleVarianceAggregator extends VarianceAggregator + { + private final boolean noNulls = NullHandling.replaceWithDefault(); + private final BaseDoubleColumnValueSelector selector; + + public DoubleVarianceAggregator(BaseDoubleColumnValueSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate() + { + if (noNulls || !selector.isNull()) { + holder.add(selector.getDouble()); + } + } + } + public static final class LongVarianceAggregator extends VarianceAggregator { + private final boolean noNulls = NullHandling.replaceWithDefault(); private final BaseLongColumnValueSelector selector; public LongVarianceAggregator(BaseLongColumnValueSelector selector) { - super(); this.selector = selector; } @Override public void aggregate() { - if (NullHandling.replaceWithDefault() || !selector.isNull()) { + if (noNulls || !selector.isNull()) { holder.add(selector.getLong()); } } @@ -108,7 +124,6 @@ public abstract class VarianceAggregator implements Aggregator public ObjectVarianceAggregator(BaseObjectColumnValueSelector selector) { - super(); this.selector = selector; } diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java index 860d9855c39..5dbcea763b4 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java @@ -134,6 +134,17 @@ public class VarianceAggregatorCollector return this; } + public VarianceAggregatorCollector add(double v) + { + count++; + sum += v; + if (count > 1) { + double t = count * v - sum; + nvariance += (t * t) / ((double) count * (count - 1)); + } + return this; + } + public VarianceAggregatorCollector add(long v) { count++; diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java index 79bbe96908e..10a04efea6f 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java @@ -38,6 +38,8 @@ import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ValueType; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -51,10 +53,12 @@ import java.util.Objects; @JsonTypeName("variance") public class VarianceAggregatorFactory extends AggregatorFactory { + private static final String VARIANCE_TYPE_NAME = "variance"; protected final String fieldName; protected final String name; @Nullable protected final String estimator; + @Nullable private final String inputType; protected final boolean isVariancePop; @@ -74,7 +78,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory this.fieldName = fieldName; this.estimator = estimator; this.isVariancePop = VarianceAggregatorCollector.isVariancePop(estimator); - this.inputType = inputType == null ? "float" : inputType; + this.inputType = inputType; } public VarianceAggregatorFactory(String name, String fieldName) @@ -85,7 +89,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory @Override public String getTypeName() { - return "variance"; + return VARIANCE_TYPE_NAME; } @Override @@ -102,15 +106,21 @@ public class VarianceAggregatorFactory extends AggregatorFactory return NoopAggregator.instance(); } - if ("float".equalsIgnoreCase(inputType)) { + final String type = getTypeString(metricFactory); + + if (ValueType.FLOAT.name().equalsIgnoreCase(type)) { return new VarianceAggregator.FloatVarianceAggregator(selector); - } else if ("long".equalsIgnoreCase(inputType)) { + } else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) { + return new VarianceAggregator.DoubleVarianceAggregator(selector); + } else if (ValueType.LONG.name().equalsIgnoreCase(type)) { return new VarianceAggregator.LongVarianceAggregator(selector); - } else if ("variance".equalsIgnoreCase(inputType)) { + } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) { return new VarianceAggregator.ObjectVarianceAggregator(selector); } throw new IAE( - "Incompatible type for metric[%s], expected a float, long or variance, got a %s", fieldName, inputType + "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s", + fieldName, + inputType ); } @@ -121,15 +131,21 @@ public class VarianceAggregatorFactory extends AggregatorFactory if (selector instanceof NilColumnValueSelector) { return NoopBufferAggregator.instance(); } - if ("float".equalsIgnoreCase(inputType)) { + final String type = getTypeString(metricFactory); + + if (ValueType.FLOAT.name().equalsIgnoreCase(type)) { return new VarianceBufferAggregator.FloatVarianceAggregator(selector); - } else if ("long".equalsIgnoreCase(inputType)) { + } else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) { + return new VarianceBufferAggregator.DoubleVarianceAggregator(selector); + } else if (ValueType.LONG.name().equalsIgnoreCase(type)) { return new VarianceBufferAggregator.LongVarianceAggregator(selector); - } else if ("variance".equalsIgnoreCase(inputType)) { + } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) { return new VarianceBufferAggregator.ObjectVarianceAggregator(selector); } throw new IAE( - "Incompatible type for metric[%s], expected a float, long or variance, got a %s", fieldName, inputType + "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s", + fieldName, + inputType ); } @@ -249,7 +265,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory @JsonProperty public String getInputType() { - return inputType; + return inputType == null ? StringUtils.toLowerCase(ValueType.FLOAT.name()) : inputType; } @Override @@ -304,4 +320,19 @@ public class VarianceAggregatorFactory extends AggregatorFactory return Objects.hash(fieldName, name, estimator, inputType, isVariancePop); } + + private String getTypeString(ColumnSelectorFactory metricFactory) + { + String type = inputType; + if (type == null) { + ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName); + if (capabilities != null) { + type = StringUtils.toLowerCase(capabilities.getType().name()); + } else { + type = StringUtils.toLowerCase(ValueType.FLOAT.name()); + } + } + return type; + } + } diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java index ae3099225bf..51ec0b1de73 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java @@ -23,9 +23,11 @@ import com.google.common.base.Preconditions; import org.apache.druid.common.config.NullHandling; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.BaseDoubleColumnValueSelector; import org.apache.druid.segment.BaseFloatColumnValueSelector; import org.apache.druid.segment.BaseLongColumnValueSelector; import org.apache.druid.segment.BaseObjectColumnValueSelector; + import java.nio.ByteBuffer; /** @@ -79,6 +81,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator public static final class FloatVarianceAggregator extends VarianceBufferAggregator { + private final boolean noNulls = NullHandling.replaceWithDefault(); private final BaseFloatColumnValueSelector selector; public FloatVarianceAggregator(BaseFloatColumnValueSelector selector) @@ -89,7 +92,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator @Override public void aggregate(ByteBuffer buf, int position) { - if (NullHandling.replaceWithDefault() || !selector.isNull()) { + if (noNulls || !selector.isNull()) { float v = selector.getFloat(); long count = buf.getLong(position + COUNT_OFFSET) + 1; double sum = buf.getDouble(position + SUM_OFFSET) + v; @@ -110,8 +113,43 @@ public abstract class VarianceBufferAggregator implements BufferAggregator } } + public static final class DoubleVarianceAggregator extends VarianceBufferAggregator + { + private final boolean noNulls = NullHandling.replaceWithDefault(); + private final BaseDoubleColumnValueSelector selector; + + public DoubleVarianceAggregator(BaseDoubleColumnValueSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + if (noNulls || !selector.isNull()) { + double v = selector.getDouble(); + long count = buf.getLong(position + COUNT_OFFSET) + 1; + double sum = buf.getDouble(position + SUM_OFFSET) + v; + buf.putLong(position, count); + buf.putDouble(position + SUM_OFFSET, sum); + if (count > 1) { + double t = count * v - sum; + double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); + buf.putDouble(position + NVARIANCE_OFFSET, variance); + } + } + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("selector", selector); + } + } + public static final class LongVarianceAggregator extends VarianceBufferAggregator { + private final boolean noNulls = NullHandling.replaceWithDefault(); private final BaseLongColumnValueSelector selector; public LongVarianceAggregator(BaseLongColumnValueSelector selector) @@ -122,7 +160,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator @Override public void aggregate(ByteBuffer buf, int position) { - if (NullHandling.replaceWithDefault() || !selector.isNull()) { + if (noNulls || !selector.isNull()) { long v = selector.getLong(); long count = buf.getLong(position + COUNT_OFFSET) + 1; double sum = buf.getDouble(position + SUM_OFFSET) + v; diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index f2da37a2333..4cdd665b74b 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -100,14 +100,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator virtualColumns.add(virtualColumn); } - if (inputType == ValueType.LONG) { - inputTypeName = "long"; - } else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) { - inputTypeName = "float"; - } else { - throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType); + switch (inputType) { + case LONG: + case DOUBLE: + case FLOAT: + inputTypeName = StringUtils.toLowerCase(inputType.name()); + break; + default: + throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType); } + if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) { estimator = "population"; } else { diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java index 33967f2ed03..6d7c4314890 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java @@ -67,6 +67,7 @@ public class VarianceAggregatorTest extends InitializedNullHandlingTest selector = new TestFloatColumnSelector(values); colSelectorFactory = EasyMock.createMock(ColumnSelectorFactory.class); EasyMock.expect(colSelectorFactory.makeColumnValueSelector("nilly")).andReturn(selector); + EasyMock.expect(colSelectorFactory.getColumnCapabilities("nilly")).andReturn(null); EasyMock.replay(colSelectorFactory); } diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java index e904f05e578..192788fed35 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -201,7 +201,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest if (raw != null) { if (raw instanceof Double) { double v = ((Double) raw).doubleValue() * multiply; - holder.add((float) v); + holder.add(v); } else if (raw instanceof Float) { float v = ((Float) raw).floatValue() * multiply; holder.add(v); @@ -263,7 +263,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest .granularity(Granularities.ALL) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"), + new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"), new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"), new VarianceAggregatorFactory("a2:agg", "l1", "population", "long") ) @@ -318,7 +318,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest .granularity(Granularities.ALL) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"), + new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"), new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"), new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long") ) @@ -373,7 +373,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest .granularity(Granularities.ALL) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"), + new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"), new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"), new VarianceAggregatorFactory("a2:agg", "l1", "population", "long") ) @@ -435,7 +435,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest .granularity(Granularities.ALL) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"), + new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"), new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"), new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long") ) @@ -501,7 +501,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest ) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"), + new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"), new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"), new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long") )