diff --git a/benchmarks/pom.xml b/benchmarks/pom.xml
index f00b93125c5..4d216855067 100644
--- a/benchmarks/pom.xml
+++ b/benchmarks/pom.xml
@@ -82,6 +82,11 @@
druid-histogram
${project.parent.version}
+
+ org.apache.druid.extensions
+ druid-stats
+ ${project.parent.version}
+
org.apache.druid
druid-core
@@ -172,7 +177,7 @@
org.apache.druid.extensions
druid-protobuf-extensions
- 0.20.0-SNAPSHOT
+ ${project.parent.version}
test
diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java
new file mode 100644
index 00000000000..85b7c4d86fe
--- /dev/null
+++ b/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.benchmark;
+
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.util.Random;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * JMH benchmark contrasting two ways of populating a {@link VarianceAggregatorCollector} from the same array of
+ * random floats: adding the values one at a time, versus computing the count, the sum and the summed squared
+ * deviations from the mean up front and passing them to the collector's three-argument constructor.
+ *
+ * Every benchmark method reports average time per invocation in nanoseconds.
+ */
+@State(Scope.Benchmark)
+@Fork(value = 1)
+@Warmup(iterations = 5)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+public class VarianceBenchmark
+{
+  @Param({"128", "256", "512", "1024"})
+  int vectorSize;
+
+  private float[] randomValues;
+
+  /**
+   * Allocates {@link #randomValues} with {@link #vectorSize} slots and fills every slot with a random float.
+   */
+  @Setup
+  public void setup()
+  {
+    final Random random = ThreadLocalRandom.current();
+    randomValues = new float[vectorSize];
+    for (int idx = 0; idx < vectorSize; idx++) {
+      randomValues[idx] = random.nextFloat();
+    }
+  }
+
+  /**
+   * Baseline: a fresh collector receives each value through its own {@code add} call.
+   */
+  @Benchmark
+  public void collectVarianceOneByOne(Blackhole blackhole)
+  {
+    final VarianceAggregatorCollector oneByOne = new VarianceAggregatorCollector();
+    for (float value : randomValues) {
+      oneByOne.add(value);
+    }
+    blackhole.consume(oneByOne);
+  }
+
+  /**
+   * Batch variant: one pass accumulates the sum, a second pass accumulates the squared deviations from the mean,
+   * and the collector is then constructed directly from (count, sum, squared deviations).
+   */
+  @Benchmark
+  public void collectVarianceInBatch(Blackhole blackhole)
+  {
+    double total = 0;
+    for (float value : randomValues) {
+      total += value;
+    }
+    final double mean = total / randomValues.length;
+
+    double squaredDeviations = 0;
+    for (float value : randomValues) {
+      squaredDeviations += (value - mean) * (value - mean);
+    }
+
+    final VarianceAggregatorCollector batched =
+        new VarianceAggregatorCollector(randomValues.length, total, squaredDeviations);
+    blackhole.consume(batched);
+  }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
index ce0edb04f41..6526a86aa86 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
@@ -76,6 +76,7 @@ public class VarianceAggregatorCollector
if (other == null || other.count == 0) {
return;
}
+
if (this.count == 0) {
this.nvariance = other.nvariance;
this.count = other.count;
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
index e9b59b48cd6..2894c019588 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
@@ -22,6 +22,7 @@ package org.apache.druid.query.aggregation.variance;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
@@ -35,12 +36,15 @@ import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.NoopAggregator;
import org.apache.druid.query.aggregation.NoopBufferAggregator;
import org.apache.druid.query.aggregation.ObjectAggregateCombiner;
+import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
+import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -83,7 +87,8 @@ public class VarianceAggregatorFactory extends AggregatorFactory
this.inputType = inputType;
}
- public VarianceAggregatorFactory(String name, String fieldName)
+ @VisibleForTesting
+ VarianceAggregatorFactory(String name, String fieldName)
{
this(name, fieldName, null, null);
}
@@ -131,7 +136,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return new VarianceAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.LongVarianceAggregator(selector);
- } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
+ } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
@@ -156,16 +161,42 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return new VarianceBufferAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.LongVarianceAggregator(selector);
- } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
+ } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
- inputType
+ type
);
}
+  /**
+   * Builds the vectorized aggregator for {@code fieldName}, chosen by the type string resolved through
+   * {@code getTypeString}. Float, double and long inputs read a value selector; inputs typed as the variance
+   * type name or as {@code COMPLEX} read an object selector.
+   *
+   * @param selectorFactory source of the vector selectors and of the column type information
+   * @return the aggregator matching the resolved input type
+   * @throws IAE if the resolved type is none of float, double, long, variance or complex
+   */
+  @Override
+  public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
+  {
+    final String type = getTypeString(selectorFactory);
+
+    if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
+      return new VarianceFloatVectorAggregator(selectorFactory.makeValueSelector(fieldName));
+    }
+    if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
+      return new VarianceDoubleVectorAggregator(selectorFactory.makeValueSelector(fieldName));
+    }
+    if (ValueType.LONG.name().equalsIgnoreCase(type)) {
+      return new VarianceLongVectorAggregator(selectorFactory.makeValueSelector(fieldName));
+    }
+
+    final boolean readsCollectors = VARIANCE_TYPE_NAME.equalsIgnoreCase(type)
+                                    || ValueType.COMPLEX.name().equalsIgnoreCase(type);
+    if (readsCollectors) {
+      return new VarianceObjectVectorAggregator(selectorFactory.makeObjectSelector(fieldName));
+    }
+
+    throw new IAE(
+        "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
+        fieldName,
+        type
+    );
+  }
+
+ // Always reports that vectorized aggregation is available; the supplied inspector is not consulted.
+ @Override
+ public boolean canVectorize(ColumnInspector columnInspector)
+ {
+ return true;
+ }
+
@Override
public Object combine(Object lhs, Object rhs)
{
@@ -340,11 +371,11 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return Objects.hash(fieldName, name, estimator, inputType, isVariancePop);
}
- private String getTypeString(ColumnSelectorFactory metricFactory)
+ private String getTypeString(ColumnInspector columnInspector)
{
String type = inputType;
if (type == null) {
- ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
+ ColumnCapabilities capabilities = columnInspector.getColumnCapabilities(fieldName);
if (capabilities != null) {
type = StringUtils.toLowerCase(capabilities.getType().name());
} else {
@@ -353,5 +384,4 @@ public class VarianceAggregatorFactory extends AggregatorFactory
}
return type;
}
-
}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
index 51ec0b1de73..065ad2aa3e0 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
@@ -35,25 +35,19 @@ import java.nio.ByteBuffer;
public abstract class VarianceBufferAggregator implements BufferAggregator
{
private static final int COUNT_OFFSET = 0;
- private static final int SUM_OFFSET = Long.BYTES;
+ private static final int SUM_OFFSET = COUNT_OFFSET + Long.BYTES;
private static final int NVARIANCE_OFFSET = SUM_OFFSET + Double.BYTES;
@Override
public void init(final ByteBuffer buf, final int position)
{
- buf.putLong(position + COUNT_OFFSET, 0)
- .putDouble(position + SUM_OFFSET, 0)
- .putDouble(position + NVARIANCE_OFFSET, 0);
+ doInit(buf, position);
}
@Override
- public Object get(final ByteBuffer buf, final int position)
+ public VarianceAggregatorCollector get(final ByteBuffer buf, final int position)
{
- VarianceAggregatorCollector holder = new VarianceAggregatorCollector();
- holder.count = buf.getLong(position);
- holder.sum = buf.getDouble(position + SUM_OFFSET);
- holder.nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
- return holder;
+ return getVarianceCollector(buf, position);
}
@Override
@@ -79,6 +73,51 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
}
+  /**
+   * Zeroes the three state slots (count, sum, nvariance) of the aggregation state that starts at
+   * {@code position}.
+   *
+   * @param buf      buffer holding the aggregation state
+   * @param position offset in {@code buf} at which the state begins
+   */
+  public static void doInit(ByteBuffer buf, int position)
+  {
+    buf.putLong(position + COUNT_OFFSET, 0L);
+    buf.putDouble(position + SUM_OFFSET, 0.0);
+    buf.putDouble(position + NVARIANCE_OFFSET, 0.0);
+  }
+
+  /**
+   * Reads the count slot of the aggregation state that starts at {@code position}.
+   *
+   * @param buf      buffer holding the aggregation state
+   * @param position offset in {@code buf} at which the state begins
+   * @return the stored count
+   */
+  public static long getCount(ByteBuffer buf, int position)
+  {
+    final int countIndex = position + COUNT_OFFSET;
+    return buf.getLong(countIndex);
+  }
+
+  /**
+   * Reads the sum slot of the aggregation state that starts at {@code position}.
+   *
+   * @param buf      buffer holding the aggregation state
+   * @param position offset in {@code buf} at which the state begins
+   * @return the stored sum
+   */
+  public static double getSum(ByteBuffer buf, int position)
+  {
+    final int sumIndex = position + SUM_OFFSET;
+    return buf.getDouble(sumIndex);
+  }
+
+  /**
+   * Reads the double stored in the nvariance slot of the aggregation state that starts at {@code position}.
+   * Note that, despite the method name, this is the raw content of the nvariance slot.
+   *
+   * @param buf      buffer holding the aggregation state
+   * @param position offset in {@code buf} at which the state begins
+   * @return the stored nvariance value
+   */
+  public static double getVariance(ByteBuffer buf, int position)
+  {
+    final int nvarianceIndex = position + NVARIANCE_OFFSET;
+    return buf.getDouble(nvarianceIndex);
+  }
+  /**
+   * Materializes the aggregation state that starts at {@code position} as a new
+   * {@link VarianceAggregatorCollector} built from the stored count, sum and nvariance.
+   *
+   * @param buf      buffer holding the aggregation state
+   * @param position offset in {@code buf} at which the state begins
+   * @return a freshly allocated collector carrying the stored state
+   */
+  public static VarianceAggregatorCollector getVarianceCollector(ByteBuffer buf, int position)
+  {
+    final long storedCount = getCount(buf, position);
+    final double storedSum = getSum(buf, position);
+    final double storedNvariance = getVariance(buf, position);
+    return new VarianceAggregatorCollector(storedCount, storedSum, storedNvariance);
+  }
+
+  /**
+   * Stores {@code count} and {@code sum} into the state that starts at {@code position}, and additionally stores
+   * {@code nvariance} when {@code count} is greater than one. For a count of zero or one the nvariance slot is
+   * left exactly as it was.
+   *
+   * @param buf       buffer holding the aggregation state
+   * @param position  offset in {@code buf} at which the state begins
+   * @param count     count to store
+   * @param sum       sum to store
+   * @param nvariance value for the nvariance slot; ignored unless {@code count > 1}
+   */
+  public static void writeNVariance(ByteBuffer buf, int position, long count, double sum, double nvariance)
+  {
+    writeCountAndSum(buf, position, count, sum);
+    if (count <= 1) {
+      return;
+    }
+    buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
+  }
+
+  /**
+   * Stores {@code count} and {@code sum} into the state that starts at {@code position}; the nvariance slot is
+   * not touched.
+   *
+   * @param buf      buffer holding the aggregation state
+   * @param position offset in {@code buf} at which the state begins
+   * @param count    count to store
+   * @param sum      sum to store
+   */
+  public static void writeCountAndSum(ByteBuffer buf, int position, long count, double sum)
+  {
+    buf.putLong(position + COUNT_OFFSET, count)
+       .putDouble(position + SUM_OFFSET, sum);
+  }
+
public static final class FloatVarianceAggregator extends VarianceBufferAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
@@ -94,10 +133,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
if (noNulls || !selector.isNull()) {
float v = selector.getFloat();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
+ long count = getCount(buf, position) + 1;
+ double sum = getSum(buf, position) + v;
+ writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@@ -128,10 +166,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
if (noNulls || !selector.isNull()) {
double v = selector.getDouble();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
+ long count = getCount(buf, position) + 1;
+ double sum = getSum(buf, position) + v;
+ writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@@ -162,10 +199,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
if (noNulls || !selector.isNull()) {
long v = selector.getLong();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
+ long count = getCount(buf, position) + 1;
+ double sum = getSum(buf, position) + v;
+ writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@@ -195,7 +231,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
VarianceAggregatorCollector holder2 = (VarianceAggregatorCollector) selector.getObject();
Preconditions.checkState(holder2 != null);
- long count = buf.getLong(position + COUNT_OFFSET);
+ long count = getCount(buf, position);
if (count == 0) {
buf.putLong(position, holder2.count);
buf.putDouble(position + SUM_OFFSET, holder2.sum);
@@ -203,7 +239,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
return;
}
- double sum = buf.getDouble(position + SUM_OFFSET);
+ double sum = getSum(buf, position);
double nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
final double ratio = count / (double) holder2.count;
@@ -213,9 +249,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
count += holder2.count;
sum += holder2.sum;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
- buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
+ writeNVariance(buf, position, count, sum, nvariance);
}
@Override
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java
new file mode 100644
index 00000000000..37c87399cf0
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for doubles.
+ *
+ * The aggregation state (count, sum, nvariance) is kept in the caller-supplied buffer and is read and written only
+ * through the static helpers of {@link VarianceBufferAggregator}. Rows flagged in the selector's null vector are
+ * skipped unless {@link NullHandling#replaceWithDefault()} is on, in which case the null vector is not consulted.
+ */
+public class VarianceDoubleVectorAggregator implements VectorAggregator
+{
+  private final VectorValueSelector selector;
+  private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
+
+  public VarianceDoubleVectorAggregator(VectorValueSelector selector)
+  {
+    this.selector = selector;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    VarianceBufferAggregator.doInit(buf, position);
+  }
+
+  /**
+   * Aggregates rows {@code [startRow, endRow)} of the current vector into the single state at {@code position}.
+   * The range is first reduced to (count, sum, summed squared deviations from its mean) and that summary is then
+   * folded into the stored state.
+   */
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    final double[] vector = selector.getDoubleVector();
+    final boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+
+    long count = 0;
+    double sum = 0;
+    for (int i = startRow; i < endRow; i++) {
+      if (nulls == null || !nulls[i]) {
+        count++;
+        sum += vector[i];
+      }
+    }
+
+    if (count == 0) {
+      // Empty range or all rows null: there is nothing to merge, so skip the 0/0 mean and leave the stored state
+      // untouched instead of reading it and writing the same values back.
+      return;
+    }
+
+    double nvariance = 0;
+    if (count > 1) {
+      // A single value has no spread, so the second pass is only needed for two or more values.
+      final double mean = sum / count;
+      for (int i = startRow; i < endRow; i++) {
+        if (nulls == null || !nulls[i]) {
+          final double delta = vector[i] - mean;
+          nvariance += delta * delta;
+        }
+      }
+    }
+
+    final VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+    previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
+    VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+  }
+
+  /**
+   * Aggregates {@code numRows} rows, each into its own state at {@code positions[i] + positionOffset}. The row
+   * aggregated for index {@code i} is {@code rows[i]}, or {@code i} itself when {@code rows} is null.
+   */
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    final double[] vector = selector.getDoubleVector();
+    final boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+
+    for (int i = 0; i < numRows; i++) {
+      final int row = rows != null ? rows[i] : i;
+      if (nulls != null && nulls[row]) {
+        continue;
+      }
+      final int position = positions[i] + positionOffset;
+      final VarianceAggregatorCollector current = VarianceBufferAggregator.getVarianceCollector(buf, position);
+      current.add(vector[row]);
+      VarianceBufferAggregator.writeNVariance(buf, position, current.count, current.sum, current.nvariance);
+    }
+  }
+
+  @Nullable
+  @Override
+  public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+  {
+    return VarianceBufferAggregator.getVarianceCollector(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // Nothing to close.
+  }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java
new file mode 100644
index 00000000000..957926e9781
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for floats.
+ *
+ * The aggregation state (count, sum, nvariance) is kept in the caller-supplied buffer and is read and written only
+ * through the static helpers of {@link VarianceBufferAggregator}. Rows flagged in the selector's null vector are
+ * skipped unless {@link NullHandling#replaceWithDefault()} is on, in which case the null vector is not consulted.
+ */
+public class VarianceFloatVectorAggregator implements VectorAggregator
+{
+  private final VectorValueSelector selector;
+  private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
+
+  public VarianceFloatVectorAggregator(VectorValueSelector selector)
+  {
+    this.selector = selector;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    VarianceBufferAggregator.doInit(buf, position);
+  }
+
+  /**
+   * Aggregates rows {@code [startRow, endRow)} of the current vector into the single state at {@code position}.
+   * The range is first reduced to (count, sum, summed squared deviations from its mean), accumulated in double
+   * precision, and that summary is then folded into the stored state.
+   */
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    final float[] vector = selector.getFloatVector();
+    final boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+
+    long count = 0;
+    double sum = 0;
+    for (int i = startRow; i < endRow; i++) {
+      if (nulls == null || !nulls[i]) {
+        count++;
+        sum += vector[i];
+      }
+    }
+
+    if (count == 0) {
+      // Empty range or all rows null: there is nothing to merge, so skip the 0/0 mean and leave the stored state
+      // untouched instead of reading it and writing the same values back.
+      return;
+    }
+
+    double nvariance = 0;
+    if (count > 1) {
+      // A single value has no spread, so the second pass is only needed for two or more values.
+      final double mean = sum / count;
+      for (int i = startRow; i < endRow; i++) {
+        if (nulls == null || !nulls[i]) {
+          final double delta = vector[i] - mean;
+          nvariance += delta * delta;
+        }
+      }
+    }
+
+    final VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+    previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
+    VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+  }
+
+  /**
+   * Aggregates {@code numRows} rows, each into its own state at {@code positions[i] + positionOffset}. The row
+   * aggregated for index {@code i} is {@code rows[i]}, or {@code i} itself when {@code rows} is null.
+   */
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    final float[] vector = selector.getFloatVector();
+    final boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+
+    for (int i = 0; i < numRows; i++) {
+      final int row = rows != null ? rows[i] : i;
+      if (nulls != null && nulls[row]) {
+        continue;
+      }
+      final int position = positions[i] + positionOffset;
+      final VarianceAggregatorCollector current = VarianceBufferAggregator.getVarianceCollector(buf, position);
+      current.add(vector[row]);
+      VarianceBufferAggregator.writeNVariance(buf, position, current.count, current.sum, current.nvariance);
+    }
+  }
+
+  @Nullable
+  @Override
+  public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+  {
+    return VarianceBufferAggregator.getVarianceCollector(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // Nothing to close.
+  }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java
new file mode 100644
index 00000000000..69941b658b6
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorValueSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for longs.
+ *
+ * The aggregation state (count, sum, nvariance) is kept in the caller-supplied buffer and is read and written only
+ * through the static helpers of {@link VarianceBufferAggregator}. Rows flagged in the selector's null vector are
+ * skipped unless {@link NullHandling#replaceWithDefault()} is on, in which case the null vector is not consulted.
+ */
+public class VarianceLongVectorAggregator implements VectorAggregator
+{
+  private final VectorValueSelector selector;
+  private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
+
+  public VarianceLongVectorAggregator(VectorValueSelector selector)
+  {
+    this.selector = selector;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    VarianceBufferAggregator.doInit(buf, position);
+  }
+
+  /**
+   * Aggregates rows {@code [startRow, endRow)} of the current vector into the single state at {@code position}.
+   * The range is first reduced to (count, sum, summed squared deviations from its mean), accumulated in double
+   * precision, and that summary is then folded into the stored state.
+   */
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    final long[] vector = selector.getLongVector();
+    final boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+
+    long count = 0;
+    double sum = 0;
+    for (int i = startRow; i < endRow; i++) {
+      if (nulls == null || !nulls[i]) {
+        count++;
+        sum += vector[i];
+      }
+    }
+
+    if (count == 0) {
+      // Empty range or all rows null: there is nothing to merge, so skip the 0/0 mean and leave the stored state
+      // untouched instead of reading it and writing the same values back.
+      return;
+    }
+
+    double nvariance = 0;
+    if (count > 1) {
+      // A single value has no spread, so the second pass is only needed for two or more values.
+      final double mean = sum / count;
+      for (int i = startRow; i < endRow; i++) {
+        if (nulls == null || !nulls[i]) {
+          final double delta = vector[i] - mean;
+          nvariance += delta * delta;
+        }
+      }
+    }
+
+    final VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+    previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
+    VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+  }
+
+  /**
+   * Aggregates {@code numRows} rows, each into its own state at {@code positions[i] + positionOffset}. The row
+   * aggregated for index {@code i} is {@code rows[i]}, or {@code i} itself when {@code rows} is null.
+   */
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    final long[] vector = selector.getLongVector();
+    final boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
+
+    for (int i = 0; i < numRows; i++) {
+      final int row = rows != null ? rows[i] : i;
+      if (nulls != null && nulls[row]) {
+        continue;
+      }
+      final int position = positions[i] + positionOffset;
+      final VarianceAggregatorCollector current = VarianceBufferAggregator.getVarianceCollector(buf, position);
+      current.add(vector[row]);
+      VarianceBufferAggregator.writeNVariance(buf, position, current.count, current.sum, current.nvariance);
+    }
+  }
+
+  @Nullable
+  @Override
+  public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+  {
+    return VarianceBufferAggregator.getVarianceCollector(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // Nothing to close.
+  }
+}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java
new file mode 100644
index 00000000000..1a7dfb09c25
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized implementation of {@link VarianceBufferAggregator} for {@link VarianceAggregatorCollector}.
+ *
+ * Each element of the selector's object vector is treated as a collector and folded into the aggregation state
+ * (count, sum, nvariance) kept in the caller-supplied buffer. The state is read and written only through the
+ * static helpers of {@link VarianceBufferAggregator}.
+ */
+public class VarianceObjectVectorAggregator implements VectorAggregator
+{
+  private final VectorObjectSelector selector;
+
+  public VarianceObjectVectorAggregator(VectorObjectSelector selector)
+  {
+    this.selector = selector;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    VarianceBufferAggregator.doInit(buf, position);
+  }
+
+  /**
+   * Folds the collectors in rows {@code [startRow, endRow)} of the current vector into the single state at
+   * {@code position}.
+   */
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    // Keep the vector as Object[] and cast element by element. Downcasting the array reference itself to
+    // VarianceAggregatorCollector[] fails with ClassCastException whenever the selector hands back an array that
+    // was allocated as a plain Object[], even if every element is a collector.
+    final Object[] vector = selector.getObjectVector();
+    final VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+    for (int i = startRow; i < endRow; i++) {
+      previous.fold((VarianceAggregatorCollector) vector[i]);
+    }
+    VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+  }
+
+  /**
+   * Folds {@code numRows} collectors, each into its own state at {@code positions[i] + positionOffset}. The row
+   * folded for index {@code i} is {@code rows[i]}, or {@code i} itself when {@code rows} is null.
+   */
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    // Element-wise cast for the same reason as in the range variant: the array may be a plain Object[].
+    final Object[] vector = selector.getObjectVector();
+    for (int i = 0; i < numRows; i++) {
+      final int position = positions[i] + positionOffset;
+      final int row = rows != null ? rows[i] : i;
+      final VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
+      previous.fold((VarianceAggregatorCollector) vector[row]);
+      VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
+    }
+  }
+
+  @Nullable
+  @Override
+  public VarianceAggregatorCollector get(ByteBuffer buf, int position)
+  {
+    return VarianceBufferAggregator.getVarianceCollector(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // Nothing to close.
+  }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java
new file mode 100644
index 00000000000..25a51300216
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.query.aggregation.Aggregator;
+import org.apache.druid.query.aggregation.BufferAggregator;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+/**
+ * Unit tests for {@link VarianceAggregatorFactory}: verifies which aggregator implementation each
+ * {@code factorize*} method returns for a given input type, and the equals/hashCode contract.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceAggregatorFactoryUnitTest extends InitializedNullHandlingTest
+{
+  private static final String NAME = "NAME";
+  private static final String FIELD_NAME = "FIELD_NAME";
+  // Type strings passed as the fourth argument of the factory constructor.
+  private static final String DOUBLE = "double";
+  private static final String LONG = "long";
+  private static final String VARIANCE = "variance";
+  private static final String UNKNOWN = "unknown";
+
+  @Mock
+  private ColumnCapabilities capabilities;
+  @Mock
+  private VectorColumnSelectorFactory selectorFactory;
+  @Mock(answer = Answers.RETURNS_MOCKS)
+  private ColumnSelectorFactory metricFactory;
+
+  private VarianceAggregatorFactory target;
+
+  /**
+   * Starts every test from a factory built with the two-argument constructor; tests that need an
+   * explicit type string replace it.
+   */
+  @Before
+  public void setup()
+  {
+    target = new VarianceAggregatorFactory(NAME, FIELD_NAME);
+  }
+
+  /**
+   * With no type string and an unstubbed selector factory, the float vector aggregator is expected.
+   */
+  @Test
+  public void factorizeVectorShouldReturnFloatVectorAggregator()
+  {
+    VectorAggregator agg = target.factorizeVector(selectorFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceFloatVectorAggregator.class, agg.getClass());
+  }
+
+  /**
+   * The "double" type string must select the double vector aggregator.
+   */
+  @Test
+  public void factorizeVectorForDoubleShouldReturnDoubleVectorAggregator()
+  {
+    target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, DOUBLE);
+    VectorAggregator agg = target.factorizeVector(selectorFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceDoubleVectorAggregator.class, agg.getClass());
+  }
+
+  /**
+   * The "long" type string must select the long vector aggregator.
+   */
+  @Test
+  public void factorizeVectorForLongShouldReturnLongVectorAggregator()
+  {
+    target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, LONG);
+    VectorAggregator agg = target.factorizeVector(selectorFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceLongVectorAggregator.class, agg.getClass());
+  }
+
+  /**
+   * The "variance" type string must select the object vector aggregator.
+   */
+  @Test
+  public void factorizeVectorForVarianceShouldReturnObjectVectorAggregator()
+  {
+    target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, VARIANCE);
+    VectorAggregator agg = target.factorizeVector(selectorFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
+  }
+
+  /**
+   * Without a type string, a column reported as COMPLEX must select the object vector aggregator.
+   */
+  @Test
+  public void factorizeVectorForComplexShouldReturnObjectVectorAggregator()
+  {
+    mockType(ValueType.COMPLEX);
+    VectorAggregator agg = target.factorizeVector(selectorFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
+  }
+
+  /**
+   * For a COMPLEX column the buffered path must produce the object flavour of the buffer aggregator.
+   */
+  @Test
+  public void factorizeBufferedForComplexShouldReturnObjectBufferAggregator()
+  {
+    mockType(ValueType.COMPLEX);
+    BufferAggregator agg = target.factorizeBuffered(metricFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceBufferAggregator.ObjectVarianceAggregator.class, agg.getClass());
+  }
+
+  /**
+   * For a COMPLEX column the non-buffered path must produce the object flavour of the aggregator.
+   */
+  @Test
+  public void factorizeForComplexShouldReturnObjectAggregator()
+  {
+    mockType(ValueType.COMPLEX);
+    Aggregator agg = target.factorize(metricFactory);
+    Assert.assertNotNull(agg);
+    Assert.assertEquals(VarianceAggregator.ObjectVarianceAggregator.class, agg.getClass());
+  }
+
+  /**
+   * An unrecognized type string must be rejected by the vectorized path.
+   */
+  @Test(expected = IAE.class)
+  public void factorizeVectorForUnknownColumnShouldThrowIAE()
+  {
+    target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
+    target.factorizeVector(selectorFactory);
+  }
+
+  /**
+   * An unrecognized type string must be rejected by the buffered path.
+   */
+  @Test(expected = IAE.class)
+  public void factorizeBufferedForUnknownColumnShouldThrowIAE()
+  {
+    target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
+    target.factorizeBuffered(metricFactory);
+  }
+
+  /**
+   * Checks equals/hashCode with EqualsVerifier, told that equals compares via getClass().
+   */
+  @Test
+  public void equalsContract()
+  {
+    EqualsVerifier.forClass(VarianceAggregatorFactory.class)
+                  .usingGetClass()
+                  .verify();
+  }
+
+  /**
+   * Makes both selector factories report the given value type for {@code FIELD_NAME}.
+   */
+  private void mockType(ValueType type)
+  {
+    Mockito.doReturn(capabilities).when(selectorFactory).getColumnCapabilities(FIELD_NAME);
+    Mockito.doReturn(capabilities).when(metricFactory).getColumnCapabilities(FIELD_NAME);
+    Mockito.doReturn(type).when(capabilities).getType();
+  }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java
new file mode 100644
index 00000000000..4204c2ac683
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.segment.vector.VectorValueSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Tests for {@link VarianceDoubleVectorAggregator}. The backing buffer is filled with random bytes,
+ * so a position only holds a clean state after it has been explicitly initialized.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceDoubleVectorAggregatorTest extends InitializedNullHandlingTest
+{
+  private static final int START_ROW = 1;
+  private static final int POSITION = 2;
+  private static final int UNINIT_POSITION = 512;
+  private static final double EPSILON = 1e-10;
+  private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60, 123};
+  private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
+
+  @Mock
+  private VectorValueSelector selector;
+  private ByteBuffer buf;
+
+  private VarianceDoubleVectorAggregator target;
+
+  @Before
+  public void setup()
+  {
+    final byte[] garbage = new byte[1024];
+    ThreadLocalRandom.current().nextBytes(garbage);
+    buf = ByteBuffer.wrap(garbage);
+    Mockito.doReturn(VALUES).when(selector).getDoubleVector();
+    target = new VarianceDoubleVectorAggregator(selector);
+    // Only POSITION starts out initialized; every other offset still holds random bytes.
+    clearBufferForPositions(0, POSITION);
+  }
+
+  @Test
+  public void initValueShouldInitZero()
+  {
+    target.init(buf, UNINIT_POSITION);
+    assertCollector(readCollector(UNINIT_POSITION), 0, 0, 0);
+  }
+
+  @Test
+  public void aggregate()
+  {
+    target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+    assertCollector(readCollector(POSITION), VALUES.length - START_ROW, 217.67, 7565.211675);
+  }
+
+  @Test
+  public void aggregateWithNulls()
+  {
+    mockNullsVector();
+    target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+
+    // When nulls are not replaced with defaults, the two rows flagged in NULLS are expected to be skipped.
+    final boolean defaults = NullHandling.replaceWithDefault();
+    final int skippedRows = defaults ? 0 : 2;
+    final double expectedSum = defaults ? 217.67 : 134;
+    final double expectedNvariance = defaults ? 7565.211675 : 6272;
+    assertCollector(
+        readCollector(POSITION),
+        VALUES.length - START_ROW - skippedRows,
+        expectedSum,
+        expectedNvariance
+    );
+  }
+
+  @Test
+  public void aggregateBatchWithoutRows()
+  {
+    final int offset = 2;
+    final int[] slots = {0, 43, 70};
+    clearBufferForPositions(offset, slots);
+    target.aggregate(buf, 3, slots, null, offset);
+
+    // Without a rows array, slot k receives vector row k.
+    int row = 0;
+    for (int slot : slots) {
+      assertCollector(readCollector(slot + offset), 1, VALUES[row], 0);
+      row++;
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRows()
+  {
+    final int offset = 2;
+    final int[] slots = {0, 43, 70};
+    final int[] rows = {3, 2, 0};
+    clearBufferForPositions(offset, slots);
+    target.aggregate(buf, 3, slots, rows, offset);
+
+    for (int k = 0; k < slots.length; k++) {
+      assertCollector(readCollector(slots[k] + offset), 1, VALUES[rows[k]], 0);
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRowsAndNulls()
+  {
+    mockNullsVector();
+    final int offset = 2;
+    final int[] slots = {0, 43, 70};
+    final int[] rows = {3, 2, 0};
+    clearBufferForPositions(offset, slots);
+    target.aggregate(buf, 3, slots, rows, offset);
+
+    final boolean keepsNulls = !NullHandling.replaceWithDefault();
+    for (int k = 0; k < slots.length; k++) {
+      final int row = rows[k];
+      // A row flagged as null is expected to leave its slot untouched.
+      if (keepsNulls && NULLS[row]) {
+        assertCollector(readCollector(slots[k] + offset), 0, 0, 0);
+      } else {
+        assertCollector(readCollector(slots[k] + offset), 1, VALUES[row], 0);
+      }
+    }
+  }
+
+  @Test
+  public void getShouldReturnAllZeros()
+  {
+    assertCollector(target.get(buf, POSITION), 0, 0, 0);
+  }
+
+  /**
+   * Asserts all three fields of a collector, comparing the floating point ones within EPSILON.
+   */
+  private static void assertCollector(
+      VarianceAggregatorCollector actual,
+      long expectedCount,
+      double expectedSum,
+      double expectedNvariance
+  )
+  {
+    Assert.assertEquals(expectedCount, actual.count);
+    Assert.assertEquals(expectedSum, actual.sum, EPSILON);
+    Assert.assertEquals(expectedNvariance, actual.nvariance, EPSILON);
+  }
+
+  /**
+   * Reads the collector currently stored in the test buffer at the given position.
+   */
+  private VarianceAggregatorCollector readCollector(int position)
+  {
+    return VarianceBufferAggregator.getVarianceCollector(buf, position);
+  }
+
+  /**
+   * Initializes the buffer at {@code offset + position} for every given position.
+   */
+  private void clearBufferForPositions(int offset, int... positions)
+  {
+    for (int k = 0; k < positions.length; k++) {
+      VarianceBufferAggregator.doInit(buf, offset + positions[k]);
+    }
+  }
+
+  /**
+   * Stubs the selector's null vector, but only when nulls are not replaced with defaults.
+   */
+  private void mockNullsVector()
+  {
+    if (NullHandling.replaceWithDefault()) {
+      return;
+    }
+    Mockito.doReturn(NULLS).when(selector).getNullVector();
+  }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java
new file mode 100644
index 00000000000..ed2f0a3c8d5
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.segment.vector.VectorValueSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Tests for {@link VarianceFloatVectorAggregator}. The backing buffer is filled with random bytes,
+ * so a position only holds a clean state after it has been explicitly initialized.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class VarianceFloatVectorAggregatorTest extends InitializedNullHandlingTest
+{
+  private static final int START_ROW = 1;
+  private static final int POSITION = 2;
+  private static final int UNINIT_POSITION = 512;
+  private static final double EPSILON = 1e-8;
+  private static final float[] VALUES = new float[]{7.8F, 11, 23.67F, 60, 123};
+  private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
+
+  @Mock
+  private VectorValueSelector selector;
+  private ByteBuffer buf;
+
+  private VarianceFloatVectorAggregator target;
+
+  @Before
+  public void setup()
+  {
+    final byte[] garbage = new byte[1024];
+    ThreadLocalRandom.current().nextBytes(garbage);
+    buf = ByteBuffer.wrap(garbage);
+    Mockito.doReturn(VALUES).when(selector).getFloatVector();
+    target = new VarianceFloatVectorAggregator(selector);
+    // Only POSITION starts out initialized; every other offset still holds random bytes.
+    clearBufferForPositions(0, POSITION);
+  }
+
+  @Test
+  public void initValueShouldInitZero()
+  {
+    target.init(buf, UNINIT_POSITION);
+    verifyState(stateAt(UNINIT_POSITION), 0, 0, 0);
+  }
+
+  @Test
+  public void aggregate()
+  {
+    target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+    verifyState(stateAt(POSITION), VALUES.length - START_ROW, 217.67000007, 7565.2116703);
+  }
+
+  @Test
+  public void aggregateWithNulls()
+  {
+    mockNullsVector();
+    target.aggregate(buf, POSITION, START_ROW, VALUES.length);
+
+    // When nulls are not replaced with defaults, the two rows flagged in NULLS are expected to be skipped.
+    final boolean defaults = NullHandling.replaceWithDefault();
+    final int skippedRows = defaults ? 0 : 2;
+    final double expectedSum = defaults ? 217.67000007 : 134;
+    final double expectedNvariance = defaults ? 7565.2116703 : 6272;
+    verifyState(
+        stateAt(POSITION),
+        VALUES.length - START_ROW - skippedRows,
+        expectedSum,
+        expectedNvariance
+    );
+  }
+
+  @Test
+  public void aggregateBatchWithoutRows()
+  {
+    final int offset = 2;
+    final int[] slots = {0, 43, 70};
+    clearBufferForPositions(offset, slots);
+    target.aggregate(buf, 3, slots, null, offset);
+
+    // Without a rows array, slot k receives vector row k.
+    int row = 0;
+    for (int slot : slots) {
+      verifyState(stateAt(slot + offset), 1, VALUES[row], 0);
+      row++;
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRows()
+  {
+    final int offset = 2;
+    final int[] slots = {0, 43, 70};
+    final int[] rows = {3, 2, 0};
+    clearBufferForPositions(offset, slots);
+    target.aggregate(buf, 3, slots, rows, offset);
+
+    for (int k = 0; k < slots.length; k++) {
+      verifyState(stateAt(slots[k] + offset), 1, VALUES[rows[k]], 0);
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRowsAndNulls()
+  {
+    mockNullsVector();
+    final int offset = 2;
+    final int[] slots = {0, 43, 70};
+    final int[] rows = {3, 2, 0};
+    clearBufferForPositions(offset, slots);
+    target.aggregate(buf, 3, slots, rows, offset);
+
+    final boolean keepsNulls = !NullHandling.replaceWithDefault();
+    for (int k = 0; k < slots.length; k++) {
+      final int row = rows[k];
+      // A row flagged as null is expected to leave its slot untouched.
+      if (keepsNulls && NULLS[row]) {
+        verifyState(stateAt(slots[k] + offset), 0, 0, 0);
+      } else {
+        verifyState(stateAt(slots[k] + offset), 1, VALUES[row], 0);
+      }
+    }
+  }
+
+  @Test
+  public void getShouldReturnAllZeros()
+  {
+    verifyState(target.get(buf, POSITION), 0, 0, 0);
+  }
+
+  /**
+   * Asserts all three fields of a collector, comparing the floating point ones within EPSILON.
+   */
+  private static void verifyState(
+      VarianceAggregatorCollector actual,
+      long expectedCount,
+      double expectedSum,
+      double expectedNvariance
+  )
+  {
+    Assert.assertEquals(expectedCount, actual.count);
+    Assert.assertEquals(expectedSum, actual.sum, EPSILON);
+    Assert.assertEquals(expectedNvariance, actual.nvariance, EPSILON);
+  }
+
+  /**
+   * Reads the collector currently stored in the test buffer at the given position.
+   */
+  private VarianceAggregatorCollector stateAt(int position)
+  {
+    return VarianceBufferAggregator.getVarianceCollector(buf, position);
+  }
+
+  /**
+   * Initializes the buffer at {@code offset + position} for every given position.
+   */
+  private void clearBufferForPositions(int offset, int... positions)
+  {
+    for (int k = 0; k < positions.length; k++) {
+      VarianceBufferAggregator.doInit(buf, offset + positions[k]);
+    }
+  }
+
+  /**
+   * Stubs the selector's null vector, but only when nulls are not replaced with defaults.
+   */
+  private void mockNullsVector()
+  {
+    if (NullHandling.replaceWithDefault()) {
+      return;
+    }
+    Mockito.doReturn(NULLS).when(selector).getNullVector();
+  }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java
index 7755f32f9d7..a91e6353c75 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java
@@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
@@ -63,14 +64,12 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
private final QueryRunner runner;
private final GroupByQueryRunnerFactory factory;
private final String testName;
+ private final GroupByQuery.Builder queryBuilder;
@Parameterized.Parameters(name = "{0}")
public static Collection