From 0b4c897fbe046c109c8e63f51221b2955cd5fe39 Mon Sep 17 00:00:00 2001 From: Suneet Saldanha Date: Thu, 17 Sep 2020 15:05:40 -0700 Subject: [PATCH] Vectorized variance aggregators (#10390) * wip vectorize * close but not quite * faster * unit tests * fix complex types for variance --- benchmarks/pom.xml | 7 +- .../druid/benchmark/VarianceBenchmark.java | 89 +++++++++ .../variance/VarianceAggregatorCollector.java | 1 + .../variance/VarianceAggregatorFactory.java | 44 ++++- .../variance/VarianceBufferAggregator.java | 88 ++++++--- .../VarianceDoubleVectorAggregator.java | 113 +++++++++++ .../VarianceFloatVectorAggregator.java | 113 +++++++++++ .../VarianceLongVectorAggregator.java | 113 +++++++++++ .../VarianceObjectVectorAggregator.java | 88 +++++++++ .../VarianceAggregatorFactoryUnitTest.java | 156 ++++++++++++++++ .../VarianceDoubleVectorAggregatorTest.java | 176 ++++++++++++++++++ .../VarianceFloatVectorAggregatorTest.java | 176 ++++++++++++++++++ .../variance/VarianceGroupByQueryTest.java | 31 +-- .../VarianceLongVectorAggregatorTest.java | 176 ++++++++++++++++++ .../VarianceObjectVectorAggregatorTest.java | 137 ++++++++++++++ .../variance/VarianceTimeseriesQueryTest.java | 12 +- .../sql/VarianceSqlAggregatorTest.java | 168 +++++++++-------- 17 files changed, 1555 insertions(+), 133 deletions(-) create mode 100644 benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java create mode 100644 extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java create mode 100644 extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java create mode 100644 extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java create mode 100644 extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java create mode 100644 extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java create mode 100644 extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java create mode 100644 extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java create mode 100644 extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java create mode 100644 extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java diff --git a/benchmarks/pom.xml b/benchmarks/pom.xml index f00b93125c5..4d216855067 100644 --- a/benchmarks/pom.xml +++ b/benchmarks/pom.xml @@ -82,6 +82,11 @@ druid-histogram ${project.parent.version} + + org.apache.druid.extensions + druid-stats + ${project.parent.version} + org.apache.druid druid-core @@ -172,7 +177,7 @@ org.apache.druid.extensions druid-protobuf-extensions - 0.20.0-SNAPSHOT + ${project.parent.version} test diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java new file mode 100644 index 00000000000..85b7c4d86fe --- /dev/null +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/VarianceBenchmark.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.benchmark; + +import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 5) +public class VarianceBenchmark +{ + @Param({"128", "256", "512", "1024"}) + int vectorSize; + + private float[] randomValues; + + @Setup + public void setup() + { + randomValues = new float[vectorSize]; + Random r = ThreadLocalRandom.current(); + for (int i = 0; i < vectorSize; i++) { + randomValues[i] = r.nextFloat(); + } + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void collectVarianceOneByOne(Blackhole blackhole) + { + VarianceAggregatorCollector collector = new VarianceAggregatorCollector(); + for (float v : randomValues) { + collector.add(v); + } + blackhole.consume(collector); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void collectVarianceInBatch(Blackhole blackhole) + { + double sum = 0, nvariance = 0; + for (float v : randomValues) { + sum += v; + } + double mean = sum / randomValues.length; + for (float v : randomValues) { + nvariance += (v - mean) * (v - mean); + } + VarianceAggregatorCollector collector = new VarianceAggregatorCollector(randomValues.length, sum, nvariance); + blackhole.consume(collector); + } +} diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java index ce0edb04f41..6526a86aa86 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java @@ -76,6 +76,7 @@ public class VarianceAggregatorCollector if (other == null || other.count == 0) { return; } + if (this.count == 0) { this.nvariance = other.nvariance; this.count = other.count; diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java index e9b59b48cd6..2894c019588 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java @@ -22,6 +22,7 @@ package org.apache.druid.query.aggregation.variance; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.IAE; @@ -35,12 +36,15 @@ import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.NoopAggregator; import org.apache.druid.query.aggregation.NoopBufferAggregator; import org.apache.druid.query.aggregation.ObjectAggregateCombiner; +import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.cache.CacheKeyBuilder; +import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.NilColumnValueSelector; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -83,7 +87,8 @@ public class VarianceAggregatorFactory extends AggregatorFactory this.inputType = inputType; } - public VarianceAggregatorFactory(String name, String fieldName) + @VisibleForTesting + VarianceAggregatorFactory(String name, String fieldName) { this(name, fieldName, null, null); } @@ -131,7 +136,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory return new VarianceAggregator.DoubleVarianceAggregator(selector); } else if (ValueType.LONG.name().equalsIgnoreCase(type)) { return new VarianceAggregator.LongVarianceAggregator(selector); - } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) { + } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) { return new VarianceAggregator.ObjectVarianceAggregator(selector); } throw new IAE( @@ -156,16 +161,42 @@ public class VarianceAggregatorFactory extends AggregatorFactory return new VarianceBufferAggregator.DoubleVarianceAggregator(selector); } else if (ValueType.LONG.name().equalsIgnoreCase(type)) { return new VarianceBufferAggregator.LongVarianceAggregator(selector); - } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) { + } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) { return new VarianceBufferAggregator.ObjectVarianceAggregator(selector); } throw new IAE( "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s", fieldName, - inputType + type ); } + @Override + public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) + { + final String type = getTypeString(selectorFactory); + if (ValueType.FLOAT.name().equalsIgnoreCase(type)) { + return new VarianceFloatVectorAggregator(selectorFactory.makeValueSelector(fieldName)); + } else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) { + return new VarianceDoubleVectorAggregator(selectorFactory.makeValueSelector(fieldName)); + } else if (ValueType.LONG.name().equalsIgnoreCase(type)) { + return new VarianceLongVectorAggregator(selectorFactory.makeValueSelector(fieldName)); + } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) { + return new VarianceObjectVectorAggregator(selectorFactory.makeObjectSelector(fieldName)); + } + throw new IAE( + "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s", + fieldName, + type + ); + } + + @Override + public boolean canVectorize(ColumnInspector columnInspector) + { + return true; + } + @Override public Object combine(Object lhs, Object rhs) { @@ -340,11 +371,11 @@ public class VarianceAggregatorFactory extends AggregatorFactory return Objects.hash(fieldName, name, estimator, inputType, isVariancePop); } - private String getTypeString(ColumnSelectorFactory metricFactory) + private String getTypeString(ColumnInspector columnInspector) { String type = inputType; if (type == null) { - ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName); + ColumnCapabilities capabilities = columnInspector.getColumnCapabilities(fieldName); if (capabilities != null) { type = StringUtils.toLowerCase(capabilities.getType().name()); } else { @@ -353,5 +384,4 @@ public class VarianceAggregatorFactory extends AggregatorFactory } return type; } - } diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java index 51ec0b1de73..065ad2aa3e0 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java @@ -35,25 +35,19 @@ import java.nio.ByteBuffer; public abstract class VarianceBufferAggregator implements BufferAggregator { private static final int COUNT_OFFSET = 0; - private static final int SUM_OFFSET = Long.BYTES; + private static final int SUM_OFFSET = COUNT_OFFSET + Long.BYTES; private static final int NVARIANCE_OFFSET = SUM_OFFSET + Double.BYTES; @Override public void init(final ByteBuffer buf, final int position) { - buf.putLong(position + COUNT_OFFSET, 0) - .putDouble(position + SUM_OFFSET, 0) - .putDouble(position + NVARIANCE_OFFSET, 0); + doInit(buf, position); } @Override - public Object get(final ByteBuffer buf, final int position) + public VarianceAggregatorCollector get(final ByteBuffer buf, final int position) { - VarianceAggregatorCollector holder = new VarianceAggregatorCollector(); - holder.count = buf.getLong(position); - holder.sum = buf.getDouble(position + SUM_OFFSET); - holder.nvariance = buf.getDouble(position + NVARIANCE_OFFSET); - return holder; + return getVarianceCollector(buf, position); } @Override @@ -79,6 +73,51 @@ public abstract class VarianceBufferAggregator implements BufferAggregator { } + public static void doInit(ByteBuffer buf, int position) + { + buf.putLong(position + COUNT_OFFSET, 0) + .putDouble(position + SUM_OFFSET, 0) + .putDouble(position + NVARIANCE_OFFSET, 0); + } + + public static long getCount(ByteBuffer buf, int position) + { + return buf.getLong(position + COUNT_OFFSET); + } + + public static double getSum(ByteBuffer buf, int position) + { + return buf.getDouble(position + SUM_OFFSET); + } + + public static double getVariance(ByteBuffer buf, int position) + { + return buf.getDouble(position + NVARIANCE_OFFSET); + } + public static VarianceAggregatorCollector getVarianceCollector(ByteBuffer buf, int position) + { + return new VarianceAggregatorCollector( + getCount(buf, position), + getSum(buf, position), + getVariance(buf, position) + ); + } + + public static void writeNVariance(ByteBuffer buf, int position, long count, double sum, double nvariance) + { + buf.putLong(position + COUNT_OFFSET, count); + buf.putDouble(position + SUM_OFFSET, sum); + if (count > 1) { + buf.putDouble(position + NVARIANCE_OFFSET, nvariance); + } + } + + public static void writeCountAndSum(ByteBuffer buf, int position, long count, double sum) + { + buf.putLong(position + COUNT_OFFSET, count); + buf.putDouble(position + SUM_OFFSET, sum); + } + public static final class FloatVarianceAggregator extends VarianceBufferAggregator { private final boolean noNulls = NullHandling.replaceWithDefault(); @@ -94,10 +133,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator { if (noNulls || !selector.isNull()) { float v = selector.getFloat(); - long count = buf.getLong(position + COUNT_OFFSET) + 1; - double sum = buf.getDouble(position + SUM_OFFSET) + v; - buf.putLong(position, count); - buf.putDouble(position + SUM_OFFSET, sum); + long count = getCount(buf, position) + 1; + double sum = getSum(buf, position) + v; + writeCountAndSum(buf, position, count, sum); if (count > 1) { double t = count * v - sum; double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); @@ -128,10 +166,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator { if (noNulls || !selector.isNull()) { double v = selector.getDouble(); - long count = buf.getLong(position + COUNT_OFFSET) + 1; - double sum = buf.getDouble(position + SUM_OFFSET) + v; - buf.putLong(position, count); - buf.putDouble(position + SUM_OFFSET, sum); + long count = getCount(buf, position) + 1; + double sum = getSum(buf, position) + v; + writeCountAndSum(buf, position, count, sum); if (count > 1) { double t = count * v - sum; double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); @@ -162,10 +199,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator { if (noNulls || !selector.isNull()) { long v = selector.getLong(); - long count = buf.getLong(position + COUNT_OFFSET) + 1; - double sum = buf.getDouble(position + SUM_OFFSET) + v; - buf.putLong(position, count); - buf.putDouble(position + SUM_OFFSET, sum); + long count = getCount(buf, position) + 1; + double sum = getSum(buf, position) + v; + writeCountAndSum(buf, position, count, sum); if (count > 1) { double t = count * v - sum; double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); @@ -195,7 +231,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator { VarianceAggregatorCollector holder2 = (VarianceAggregatorCollector) selector.getObject(); Preconditions.checkState(holder2 != null); - long count = buf.getLong(position + COUNT_OFFSET); + long count = getCount(buf, position); if (count == 0) { buf.putLong(position, holder2.count); buf.putDouble(position + SUM_OFFSET, holder2.sum); @@ -203,7 +239,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator return; } - double sum = buf.getDouble(position + SUM_OFFSET); + double sum = getSum(buf, position); double nvariance = buf.getDouble(position + NVARIANCE_OFFSET); final double ratio = count / (double) holder2.count; @@ -213,9 +249,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator count += holder2.count; sum += holder2.sum; - buf.putLong(position, count); - buf.putDouble(position + SUM_OFFSET, sum); - buf.putDouble(position + NVARIANCE_OFFSET, nvariance); + writeNVariance(buf, position, count, sum, nvariance); } @Override diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java new file mode 100644 index 00000000000..37c87399cf0 --- /dev/null +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregator.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.vector.VectorValueSelector; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; + +/** + * Vectorized implementation of {@link VarianceBufferAggregator} for doubles. + */ +public class VarianceDoubleVectorAggregator implements VectorAggregator +{ + private final VectorValueSelector selector; + private final boolean replaceWithDefault = NullHandling.replaceWithDefault(); + + public VarianceDoubleVectorAggregator(VectorValueSelector selector) + { + this.selector = selector; + } + + @Override + public void init(ByteBuffer buf, int position) + { + VarianceBufferAggregator.doInit(buf, position); + } + + @Override + public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) + { + double[] vector = selector.getDoubleVector(); + long count = 0; + double sum = 0, nvariance = 0; + boolean[] nulls = replaceWithDefault ? null : selector.getNullVector(); + for (int i = startRow; i < endRow; i++) { + if (nulls == null || !nulls[i]) { + count++; + sum += vector[i]; + } + } + double mean = sum / count; + if (count > 1) { + for (int i = startRow; i < endRow; i++) { + if (nulls == null || !nulls[i]) { + nvariance += (vector[i] - mean) * (vector[i] - mean); + } + } + } + + VarianceAggregatorCollector previous = new VarianceAggregatorCollector( + VarianceBufferAggregator.getCount(buf, position), + VarianceBufferAggregator.getSum(buf, position), + VarianceBufferAggregator.getVariance(buf, position) + ); + previous.fold(new VarianceAggregatorCollector(count, sum, nvariance)); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + + @Override + public void aggregate( + ByteBuffer buf, + int numRows, + int[] positions, + @Nullable int[] rows, + int positionOffset + ) + { + double[] vector = selector.getDoubleVector(); + boolean[] nulls = replaceWithDefault ? null : selector.getNullVector(); + for (int i = 0; i < numRows; i++) { + int position = positions[i] + positionOffset; + int row = rows != null ? rows[i] : i; + if (nulls == null || !nulls[row]) { + VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position); + previous.add(vector[row]); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + } + } + + @Nullable + @Override + public VarianceAggregatorCollector get(ByteBuffer buf, int position) + { + return VarianceBufferAggregator.getVarianceCollector(buf, position); + } + + @Override + public void close() + { + // Nothing to close. + } +} diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java new file mode 100644 index 00000000000..957926e9781 --- /dev/null +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregator.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.vector.VectorValueSelector; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; + +/** + * Vectorized implementation of {@link VarianceBufferAggregator} for floats. + */ +public class VarianceFloatVectorAggregator implements VectorAggregator +{ + private final VectorValueSelector selector; + private final boolean replaceWithDefault = NullHandling.replaceWithDefault(); + + public VarianceFloatVectorAggregator(VectorValueSelector selector) + { + this.selector = selector; + } + + @Override + public void init(ByteBuffer buf, int position) + { + VarianceBufferAggregator.doInit(buf, position); + } + + @Override + public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) + { + float[] vector = selector.getFloatVector(); + long count = 0; + double sum = 0, nvariance = 0; + boolean[] nulls = replaceWithDefault ? null : selector.getNullVector(); + for (int i = startRow; i < endRow; i++) { + if (nulls == null || !nulls[i]) { + count++; + sum += vector[i]; + } + } + double mean = sum / count; + if (count > 1) { + for (int i = startRow; i < endRow; i++) { + if (nulls == null || !nulls[i]) { + nvariance += (vector[i] - mean) * (vector[i] - mean); + } + } + } + + VarianceAggregatorCollector previous = new VarianceAggregatorCollector( + VarianceBufferAggregator.getCount(buf, position), + VarianceBufferAggregator.getSum(buf, position), + VarianceBufferAggregator.getVariance(buf, position) + ); + previous.fold(new VarianceAggregatorCollector(count, sum, nvariance)); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + + @Override + public void aggregate( + ByteBuffer buf, + int numRows, + int[] positions, + @Nullable int[] rows, + int positionOffset + ) + { + float[] vector = selector.getFloatVector(); + boolean[] nulls = replaceWithDefault ? null : selector.getNullVector(); + for (int i = 0; i < numRows; i++) { + int position = positions[i] + positionOffset; + int row = rows != null ? rows[i] : i; + if (nulls == null || !nulls[row]) { + VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position); + previous.add(vector[row]); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + } + } + + @Nullable + @Override + public VarianceAggregatorCollector get(ByteBuffer buf, int position) + { + return VarianceBufferAggregator.getVarianceCollector(buf, position); + } + + @Override + public void close() + { + // Nothing to close. + } +} diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java new file mode 100644 index 00000000000..69941b658b6 --- /dev/null +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregator.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.vector.VectorValueSelector; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; + +/** + * Vectorized implementation of {@link VarianceBufferAggregator} for longs. + */ +public class VarianceLongVectorAggregator implements VectorAggregator +{ + private final VectorValueSelector selector; + private final boolean replaceWithDefault = NullHandling.replaceWithDefault(); + + public VarianceLongVectorAggregator(VectorValueSelector selector) + { + this.selector = selector; + } + + @Override + public void init(ByteBuffer buf, int position) + { + VarianceBufferAggregator.doInit(buf, position); + } + + @Override + public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) + { + long[] vector = selector.getLongVector(); + long count = 0; + double sum = 0, nvariance = 0; + boolean[] nulls = replaceWithDefault ? null : selector.getNullVector(); + for (int i = startRow; i < endRow; i++) { + if (nulls == null || !nulls[i]) { + count++; + sum += vector[i]; + } + } + double mean = sum / count; + if (count > 1) { + for (int i = startRow; i < endRow; i++) { + if (nulls == null || !nulls[i]) { + nvariance += (vector[i] - mean) * (vector[i] - mean); + } + } + } + + VarianceAggregatorCollector previous = new VarianceAggregatorCollector( + VarianceBufferAggregator.getCount(buf, position), + VarianceBufferAggregator.getSum(buf, position), + VarianceBufferAggregator.getVariance(buf, position) + ); + previous.fold(new VarianceAggregatorCollector(count, sum, nvariance)); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + + @Override + public void aggregate( + ByteBuffer buf, + int numRows, + int[] positions, + @Nullable int[] rows, + int positionOffset + ) + { + long[] vector = selector.getLongVector(); + boolean[] nulls = replaceWithDefault ? null : selector.getNullVector(); + for (int i = 0; i < numRows; i++) { + int position = positions[i] + positionOffset; + int row = rows != null ? rows[i] : i; + if (nulls == null || !nulls[row]) { + VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position); + previous.add(vector[row]); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + } + } + + @Nullable + @Override + public VarianceAggregatorCollector get(ByteBuffer buf, int position) + { + return VarianceBufferAggregator.getVarianceCollector(buf, position); + } + + @Override + public void close() + { + // Nothing to close. + } +} diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java new file mode 100644 index 00000000000..1a7dfb09c25 --- /dev/null +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.vector.VectorObjectSelector; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; + +/** + * Vectorized implementation of {@link VarianceBufferAggregator} for {@link VarianceAggregatorCollector}. + */ +public class VarianceObjectVectorAggregator implements VectorAggregator +{ + private final VectorObjectSelector selector; + + public VarianceObjectVectorAggregator(VectorObjectSelector selector) + { + this.selector = selector; + } + + @Override + public void init(ByteBuffer buf, int position) + { + VarianceBufferAggregator.doInit(buf, position); + } + + @Override + public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) + { + VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector(); + VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position); + for (int i = startRow; i < endRow; i++) { + previous.fold(vector[i]); + } + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + + @Override + public void aggregate( + ByteBuffer buf, + int numRows, + int[] positions, + @Nullable int[] rows, + int positionOffset + ) + { + VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector(); + for (int i = 0; i < numRows; i++) { + int position = positions[i] + positionOffset; + int row = rows != null ? rows[i] : i; + VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position); + previous.fold(vector[row]); + VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance); + } + } + + @Nullable + @Override + public VarianceAggregatorCollector get(ByteBuffer buf, int position) + { + return VarianceBufferAggregator.getVarianceCollector(buf, position); + } + + @Override + public void close() + { + // Nothing to close. + } +} diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java new file mode 100644 index 00000000000..25a51300216 --- /dev/null +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryUnitTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.query.aggregation.Aggregator; +import org.apache.druid.query.aggregation.BufferAggregator; +import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class VarianceAggregatorFactoryUnitTest extends InitializedNullHandlingTest +{ + private static final String NAME = "NAME"; + private static final String FIELD_NAME = "FIELD_NAME"; + private static final String DOUBLE = "double"; + private static final String LONG = "long"; + private static final String VARIANCE = "variance"; + private static final String UNKNOWN = "unknown"; + + @Mock + private ColumnCapabilities capabilities; + @Mock + private VectorColumnSelectorFactory selectorFactory; + @Mock(answer = Answers.RETURNS_MOCKS) + private ColumnSelectorFactory metricFactory; + + private VarianceAggregatorFactory target; + + @Before + public void setup() + { + target = new VarianceAggregatorFactory(NAME, FIELD_NAME); + } + + @Test + public void factorizeVectorShouldReturnFloatVectorAggregator() + { + VectorAggregator agg = target.factorizeVector(selectorFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceFloatVectorAggregator.class, agg.getClass()); + } + + @Test + public void factorizeVectorForDoubleShouldReturnFloatVectorAggregator() + { + target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, DOUBLE); + VectorAggregator agg = target.factorizeVector(selectorFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceDoubleVectorAggregator.class, agg.getClass()); + } + + @Test + public void factorizeVectorForLongShouldReturnFloatVectorAggregator() + { + target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, LONG); + VectorAggregator agg = target.factorizeVector(selectorFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceLongVectorAggregator.class, agg.getClass()); + } + + @Test + public void factorizeVectorForVarianceShouldReturnObjectVectorAggregator() + { + target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, VARIANCE); + VectorAggregator agg = target.factorizeVector(selectorFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass()); + } + + @Test + public void factorizeVectorForComplexShouldReturnObjectVectorAggregator() + { + mockType(ValueType.COMPLEX); + VectorAggregator agg = target.factorizeVector(selectorFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass()); + } + + @Test + public void factorizeBufferedForComplexShouldReturnObjectVectorAggregator() + { + mockType(ValueType.COMPLEX); + BufferAggregator agg = target.factorizeBuffered(metricFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceBufferAggregator.ObjectVarianceAggregator.class, agg.getClass()); + } + + @Test + public void factorizeForComplexShouldReturnObjectVectorAggregator() + { + mockType(ValueType.COMPLEX); + Aggregator agg = target.factorize(metricFactory); + Assert.assertNotNull(agg); + Assert.assertEquals(VarianceAggregator.ObjectVarianceAggregator.class, agg.getClass()); + } + + @Test(expected = IAE.class) + public void factorizeVectorForUnknownColumnShouldThrowIAE() + { + target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN); + target.factorizeVector(selectorFactory); + } + + @Test(expected = IAE.class) + public void factorizeBufferedForUnknownColumnShouldThrowIAE() + { + target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN); + target.factorizeBuffered(metricFactory); + } + + @Test + public void equalsContract() + { + EqualsVerifier.forClass(VarianceAggregatorFactory.class) + .usingGetClass() + .verify(); + } + + private void mockType(ValueType type) + { + Mockito.doReturn(capabilities).when(selectorFactory).getColumnCapabilities(FIELD_NAME); + Mockito.doReturn(capabilities).when(metricFactory).getColumnCapabilities(FIELD_NAME); + Mockito.doReturn(type).when(capabilities).getType(); + } +} diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java new file mode 100644 index 00000000000..4204c2ac683 --- /dev/null +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceDoubleVectorAggregatorTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; + +@RunWith(MockitoJUnitRunner.class) +public class VarianceDoubleVectorAggregatorTest extends InitializedNullHandlingTest +{ + private static final int START_ROW = 1; + private static final int POSITION = 2; + private static final int UNINIT_POSITION = 512; + private static final double EPSILON = 1e-10; + private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60, 123}; + private static final boolean[] NULLS = new boolean[]{false, false, true, true, false}; + + @Mock + private VectorValueSelector selector; + private ByteBuffer buf; + + private VarianceDoubleVectorAggregator target; + + @Before + public void setup() + { + byte[] randomBytes = new byte[1024]; + ThreadLocalRandom.current().nextBytes(randomBytes); + buf = ByteBuffer.wrap(randomBytes); + Mockito.doReturn(VALUES).when(selector).getDoubleVector(); + target = new VarianceDoubleVectorAggregator(selector); + clearBufferForPositions(0, POSITION); + } + + @Test + public void initValueShouldInitZero() + { + target.init(buf, UNINIT_POSITION); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + + @Test + public void aggregate() + { + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals(VALUES.length - START_ROW, collector.count); + Assert.assertEquals(217.67, collector.sum, EPSILON); + Assert.assertEquals(7565.211675, collector.nvariance, EPSILON); + } + + @Test + public void aggregateWithNulls() + { + mockNullsVector(); + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals( + VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2), + collector.count + ); + Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67 : 134, collector.sum, EPSILON); + Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.211675 : 6272, collector.nvariance, EPSILON); + } + + @Test + public void aggregateBatchWithoutRows() + { + int[] positions = new int[]{0, 43, 70}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, null, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(1, collector.count); + Assert.assertEquals(VALUES[i], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void aggregateBatchWithRows() + { + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(1, collector.count); + Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void aggregateBatchWithRowsAndNulls() + { + mockNullsVector(); + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]]; + Assert.assertEquals(isNull ? 0 : 1, collector.count); + Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void getShouldReturnAllZeros() + { + VarianceAggregatorCollector collector = target.get(buf, POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + private void clearBufferForPositions(int offset, int... positions) + { + for (int position : positions) { + VarianceBufferAggregator.doInit(buf, offset + position); + } + } + + private void mockNullsVector() + { + if (!NullHandling.replaceWithDefault()) { + Mockito.doReturn(NULLS).when(selector).getNullVector(); + } + } +} diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java new file mode 100644 index 00000000000..ed2f0a3c8d5 --- /dev/null +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceFloatVectorAggregatorTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; + +@RunWith(MockitoJUnitRunner.class) +public class VarianceFloatVectorAggregatorTest extends InitializedNullHandlingTest +{ + private static final int START_ROW = 1; + private static final int POSITION = 2; + private static final int UNINIT_POSITION = 512; + private static final double EPSILON = 1e-8; + private static final float[] VALUES = new float[]{7.8F, 11, 23.67F, 60, 123}; + private static final boolean[] NULLS = new boolean[]{false, false, true, true, false}; + + @Mock + private VectorValueSelector selector; + private ByteBuffer buf; + + private VarianceFloatVectorAggregator target; + + @Before + public void setup() + { + byte[] randomBytes = new byte[1024]; + ThreadLocalRandom.current().nextBytes(randomBytes); + buf = ByteBuffer.wrap(randomBytes); + Mockito.doReturn(VALUES).when(selector).getFloatVector(); + target = new VarianceFloatVectorAggregator(selector); + clearBufferForPositions(0, POSITION); + } + + @Test + public void initValueShouldInitZero() + { + target.init(buf, UNINIT_POSITION); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + + @Test + public void aggregate() + { + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals(VALUES.length - START_ROW, collector.count); + Assert.assertEquals(217.67000007, collector.sum, EPSILON); + Assert.assertEquals(7565.2116703, collector.nvariance, EPSILON); + } + + @Test + public void aggregateWithNulls() + { + mockNullsVector(); + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals( + VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2), + collector.count + ); + Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67000007 : 134, collector.sum, EPSILON); + Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.2116703 : 6272, collector.nvariance, EPSILON); + } + + @Test + public void aggregateBatchWithoutRows() + { + int[] positions = new int[]{0, 43, 70}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, null, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(1, collector.count); + Assert.assertEquals(VALUES[i], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void aggregateBatchWithRows() + { + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(1, collector.count); + Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void aggregateBatchWithRowsAndNulls() + { + mockNullsVector(); + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]]; + Assert.assertEquals(isNull ? 0 : 1, collector.count); + Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void getShouldReturnAllZeros() + { + VarianceAggregatorCollector collector = target.get(buf, POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + private void clearBufferForPositions(int offset, int... positions) + { + for (int position : positions) { + VarianceBufferAggregator.doInit(buf, offset + position); + } + } + + private void mockNullsVector() + { + if (!NullHandling.replaceWithDefault()) { + Mockito.doReturn(NULLS).when(selector).getNullVector(); + } + } +} diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java index 7755f32f9d7..a91e6353c75 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceGroupByQueryTest.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.variance; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import org.apache.druid.data.input.Row; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.PeriodGranularity; @@ -63,14 +64,12 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest private final QueryRunner runner; private final GroupByQueryRunnerFactory factory; private final String testName; + private final GroupByQuery.Builder queryBuilder; @Parameterized.Parameters(name = "{0}") public static Collection constructorFeeder() { - // Use GroupByQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator - // can't vectorize yet. return GroupByQueryRunnerTest.constructorFeeder().stream() - .filter(constructor -> !((boolean) constructor[4]) /* !vectorize */) .map( constructor -> new Object[]{ @@ -94,13 +93,14 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest this.config = config; this.factory = factory; this.runner = factory.mergeRunners(Execs.directExecutor(), ImmutableList.of(runner)); + this.queryBuilder = GroupByQuery.builder() + .setContext(ImmutableMap.of("vectorize", config.isVectorize())); } @Test public void testGroupByVarianceOnly() { - GroupByQuery query = GroupByQuery - .builder() + GroupByQuery query = queryBuilder .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setDimensions(new DefaultDimensionSpec("quality", "alias")) @@ -141,8 +141,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest @Test public void testGroupBy() { - GroupByQuery query = GroupByQuery - .builder() + GroupByQuery query = queryBuilder .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setDimensions(new DefaultDimensionSpec("quality", "alias")) @@ -191,8 +190,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest new String[]{"alias", "rows", "index", "index_var", "index_stddev"} ); - GroupByQuery query = GroupByQuery - .builder() + GroupByQuery query = queryBuilder .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setInterval("2011-04-02/2011-04-04") .setDimensions(new DefaultDimensionSpec("quality", "alias")) @@ -244,8 +242,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest public void testGroupByZtestPostAgg() { // test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test - GroupByQuery query = GroupByQuery - .builder() + GroupByQuery query = queryBuilder .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setDimensions(new DefaultDimensionSpec("quality", "alias")) @@ -286,8 +283,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest public void testGroupByTestPvalueZscorePostAgg() { // test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test - GroupByQuery query = GroupByQuery - .builder() + GroupByQuery query = queryBuilder .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setDimensions(new DefaultDimensionSpec("quality", "alias")) @@ -308,7 +304,14 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest .build(); VarianceTestHelper.RowBuilder builder = - new VarianceTestHelper.RowBuilder(new String[]{"alias", "rows", "idx", "index_stddev", "index_var", "pvalueZscore"}); + new VarianceTestHelper.RowBuilder(new String[]{ + "alias", + "rows", + "idx", + "index_stddev", + "index_var", + "pvalueZscore" + }); List expectedResults = builder .add("2011-04-01", "automotive", 1L, 135.0, 0.0, 0.0, 1.0) diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java new file mode 100644 index 00000000000..d47bf5cd3f3 --- /dev/null +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceLongVectorAggregatorTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.segment.vector.VectorValueSelector; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; + +@RunWith(MockitoJUnitRunner.class) +public class VarianceLongVectorAggregatorTest extends InitializedNullHandlingTest +{ + private static final int START_ROW = 1; + private static final int POSITION = 2; + private static final int UNINIT_POSITION = 512; + private static final double EPSILON = 1e-10; + private static final long[] VALUES = new long[]{7, 11, 23, 60, 123}; + private static final boolean[] NULLS = new boolean[]{false, false, true, true, false}; + + @Mock + private VectorValueSelector selector; + private ByteBuffer buf; + + private VarianceLongVectorAggregator target; + + @Before + public void setup() + { + byte[] randomBytes = new byte[1024]; + ThreadLocalRandom.current().nextBytes(randomBytes); + buf = ByteBuffer.wrap(randomBytes); + Mockito.doReturn(VALUES).when(selector).getLongVector(); + target = new VarianceLongVectorAggregator(selector); + clearBufferForPositions(0, POSITION); + } + + @Test + public void initValueShouldInitZero() + { + target.init(buf, UNINIT_POSITION); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + + @Test + public void aggregate() + { + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals(VALUES.length - START_ROW, collector.count); + Assert.assertEquals(217, collector.sum, EPSILON); + Assert.assertEquals(7606.75, collector.nvariance, EPSILON); + } + + @Test + public void aggregateWithNulls() + { + mockNullsVector(); + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals( + VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2), + collector.count + ); + Assert.assertEquals(NullHandling.replaceWithDefault() ? 217 : 134, collector.sum, EPSILON); + Assert.assertEquals(NullHandling.replaceWithDefault() ? 7606.75 : 6272, collector.nvariance, EPSILON); + } + + @Test + public void aggregateBatchWithoutRows() + { + int[] positions = new int[]{0, 43, 70}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, null, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(1, collector.count); + Assert.assertEquals(VALUES[i], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void aggregateBatchWithRows() + { + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(1, collector.count); + Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void aggregateBatchWithRowsAndNulls() + { + mockNullsVector(); + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]]; + Assert.assertEquals(isNull ? 0 : 1, collector.count); + Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + } + + @Test + public void getShouldReturnAllZeros() + { + VarianceAggregatorCollector collector = target.get(buf, POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + private void clearBufferForPositions(int offset, int... positions) + { + for (int position : positions) { + VarianceBufferAggregator.doInit(buf, offset + position); + } + } + + private void mockNullsVector() + { + if (!NullHandling.replaceWithDefault()) { + Mockito.doReturn(NULLS).when(selector).getNullVector(); + } + } +} diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java new file mode 100644 index 00000000000..0e6694ae099 --- /dev/null +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceObjectVectorAggregatorTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.variance; + +import org.apache.druid.segment.vector.VectorObjectSelector; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; + +@RunWith(MockitoJUnitRunner.class) +public class VarianceObjectVectorAggregatorTest extends InitializedNullHandlingTest +{ + private static final int START_ROW = 1; + private static final int POSITION = 2; + private static final int UNINIT_POSITION = 512; + private static final double EPSILON = 1e-10; + private static final VarianceAggregatorCollector[] VALUES = new VarianceAggregatorCollector[]{ + new VarianceAggregatorCollector(1, 7.8, 0), + new VarianceAggregatorCollector(1, 11, 0), + new VarianceAggregatorCollector(1, 23.67, 0), + null, + new VarianceAggregatorCollector(2, 183, 1984.5) + }; + private static final boolean[] NULLS = new boolean[]{false, false, true, true, false}; + + @Mock + private VectorObjectSelector selector; + private ByteBuffer buf; + + private VarianceObjectVectorAggregator target; + + @Before + public void setup() + { + byte[] randomBytes = new byte[1024]; + ThreadLocalRandom.current().nextBytes(randomBytes); + buf = ByteBuffer.wrap(randomBytes); + Mockito.doReturn(VALUES).when(selector).getObjectVector(); + target = new VarianceObjectVectorAggregator(selector); + clearBufferForPositions(0, POSITION); + } + + @Test + public void initValueShouldInitZero() + { + target.init(buf, UNINIT_POSITION); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + + @Test + public void aggregate() + { + target.aggregate(buf, POSITION, START_ROW, VALUES.length); + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION); + Assert.assertEquals(4, collector.count); + Assert.assertEquals(217.67, collector.sum, EPSILON); + Assert.assertEquals(7565.211675, collector.nvariance, EPSILON); + } + + @Test + public void aggregateBatchWithoutRows() + { + int[] positions = new int[]{0, 43, 70}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, null, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + Assert.assertEquals(VALUES[i], collector); + } + } + + @Test + public void aggregateBatchWithRows() + { + int[] positions = new int[]{0, 43, 70}; + int[] rows = new int[]{3, 2, 0}; + int positionOffset = 2; + clearBufferForPositions(positionOffset, positions); + target.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector( + buf, + positions[i] + positionOffset + ); + VarianceAggregatorCollector expectedCollector = VALUES[rows[i]]; + Assert.assertEquals(expectedCollector == null ? new VarianceAggregatorCollector() : expectedCollector, collector); + } + } + + @Test + public void getShouldReturnAllZeros() + { + VarianceAggregatorCollector collector = target.get(buf, POSITION); + Assert.assertEquals(0, collector.count); + Assert.assertEquals(0, collector.sum, EPSILON); + Assert.assertEquals(0, collector.nvariance, EPSILON); + } + + private void clearBufferForPositions(int offset, int... positions) + { + for (int position : positions) { + VarianceBufferAggregator.doInit(buf, offset + position); + } + } +} diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java index 9c52961f7c6..fd28b386f0c 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceTimeseriesQueryTest.java @@ -19,6 +19,7 @@ package org.apache.druid.query.aggregation.variance; +import com.google.common.collect.ImmutableMap; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.query.Druids; import org.apache.druid.query.QueryPlus; @@ -46,31 +47,32 @@ public class VarianceTimeseriesQueryTest extends InitializedNullHandlingTest @Parameterized.Parameters(name = "{0}:descending={1}") public static Iterable constructorFeeder() { - // Use TimeseriesQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator - // can't vectorize yet. return StreamSupport.stream(TimeseriesQueryRunnerTest.constructorFeeder().spliterator(), false) - .filter(constructor -> !((boolean) constructor[2]) /* !vectorize */) - .map(constructor -> new Object[]{constructor[0], constructor[1], constructor[3]}) + .map(constructor -> new Object[]{constructor[0], constructor[1], constructor[2], constructor[3]}) .collect(Collectors.toList()); } private final QueryRunner runner; private final boolean descending; + private final Druids.TimeseriesQueryBuilder queryBuilder; public VarianceTimeseriesQueryTest( QueryRunner runner, boolean descending, + boolean vectorize, List aggregatorFactories ) { this.runner = runner; this.descending = descending; + this.queryBuilder = Druids.newTimeseriesQueryBuilder() + .context(ImmutableMap.of("vectorize", vectorize ? "force" : "false")); } @Test public void testTimeseriesWithNullFilterOnNonExistentDimension() { - TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() + TimeseriesQuery query = queryBuilder .dataSource(QueryRunnerTestHelper.DATA_SOURCE) .granularity(QueryRunnerTestHelper.DAY_GRAN) .filters("bobby", null) diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java index cfb945b99cc..344bbb98fc1 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -285,16 +285,16 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest Assert.assertEquals( Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) - .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) - .granularity(Granularities.ALL) - .aggregators( - ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"), - new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"), - new VarianceAggregatorFactory("a2:agg", "l1", "population", "long") - ) - ) + .dataSource(CalciteTests.DATASOURCE3) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators( + ImmutableList.of( + new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"), + new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"), + new VarianceAggregatorFactory("a2:agg", "l1", "population", "long") + ) + ) .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT) .build(), Iterables.getOnlyElement(queryLogHook.getRecordedQueries()) @@ -335,22 +335,22 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest holder1.getVariance(false), holder2.getVariance(false).floatValue(), holder3.getVariance(false).longValue(), - } + } ); assertResultsEquals(expectedResults, results); Assert.assertEquals( Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) - .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) - .granularity(Granularities.ALL) - .aggregators( - ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"), - new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"), - new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long") - ) - ) + .dataSource(CalciteTests.DATASOURCE3) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators( + ImmutableList.of( + new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"), + new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"), + new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long") + ) + ) .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT) .build(), Iterables.getOnlyElement(queryLogHook.getRecordedQueries()) @@ -391,28 +391,29 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest Math.sqrt(holder1.getVariance(true)), (float) Math.sqrt(holder2.getVariance(true)), (long) Math.sqrt(holder3.getVariance(true)), - } + } ); assertResultsEquals(expectedResults, results); Assert.assertEquals( Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) - .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) - .granularity(Granularities.ALL) - .aggregators( - ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"), - new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"), - new VarianceAggregatorFactory("a2:agg", "l1", "population", "long") - ) - ) - .postAggregators( - ImmutableList.of( - new StandardDeviationPostAggregator("a0", "a0:agg", "population"), - new StandardDeviationPostAggregator("a1", "a1:agg", "population"), - new StandardDeviationPostAggregator("a2", "a2:agg", "population")) - ) + .dataSource(CalciteTests.DATASOURCE3) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators( + ImmutableList.of( + new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"), + new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"), + new VarianceAggregatorFactory("a2:agg", "l1", "population", "long") + ) + ) + .postAggregators( + ImmutableList.of( + new StandardDeviationPostAggregator("a0", "a0:agg", "population"), + new StandardDeviationPostAggregator("a1", "a1:agg", "population"), + new StandardDeviationPostAggregator("a2", "a2:agg", "population") + ) + ) .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT) .build(), Iterables.getOnlyElement(queryLogHook.getRecordedQueries()) @@ -453,7 +454,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest Math.sqrt(holder1.getVariance(false)), (float) Math.sqrt(holder2.getVariance(false)), (long) Math.sqrt(holder3.getVariance(false)), - } + } ); assertResultsEquals(expectedResults, results); @@ -464,9 +465,9 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest .granularity(Granularities.ALL) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"), - new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"), - new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long") + new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"), + new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"), + new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long") ) ) .postAggregators( @@ -514,7 +515,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest Math.sqrt(holder1.getVariance(false)), (float) Math.sqrt(holder2.getVariance(false)), (long) Math.sqrt(holder3.getVariance(false)), - } + } ); assertResultsEquals(expectedResults, results); @@ -530,9 +531,9 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest ) .aggregators( ImmutableList.of( - new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"), - new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"), - new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long") + new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"), + new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"), + new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long") ) ) .postAggregators( @@ -560,41 +561,41 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest authenticationResult ).toList(); List expectedResults = NullHandling.sqlCompatible() - ? ImmutableList.of( - new Object[] {"a", 0f}, - new Object[] {null, 0f}, - new Object[] {"", 0f}, - new Object[] {"abc", null} + ? ImmutableList.of( + new Object[]{"a", 0f}, + new Object[]{null, 0f}, + new Object[]{"", 0f}, + new Object[]{"abc", null} ) : ImmutableList.of( - new Object[] {"a", 0.5f}, - new Object[] {"", 0.0033333334f}, - new Object[] {"abc", 0f} + new Object[]{"a", 0.5f}, + new Object[]{"", 0.0033333334f}, + new Object[]{"abc", 0f} ); assertResultsEquals(expectedResults, results); Assert.assertEquals( GroupByQuery.builder() - .setDataSource(CalciteTests.DATASOURCE3) - .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) - .setGranularity(Granularities.ALL) - .setDimensions(new DefaultDimensionSpec("dim2", "_d0")) - .setAggregatorSpecs( - new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float") - ) - .setLimitSpec( - DefaultLimitSpec - .builder() - .orderBy( - new OrderByColumnSpec( - "a0:agg", - OrderByColumnSpec.Direction.DESCENDING, - StringComparators.NUMERIC - ) - ) - .build() - ) - .setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT) - .build(), + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("dim2", "_d0")) + .setAggregatorSpecs( + new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float") + ) + .setLimitSpec( + DefaultLimitSpec + .builder() + .orderBy( + new OrderByColumnSpec( + "a0:agg", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ) + ) + .build() + ) + .setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT) + .build(), Iterables.getOnlyElement(queryLogHook.getRecordedQueries()) ); } @@ -622,7 +623,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest Arrays.asList( QueryRunnerTestHelper.ROWS_COUNT, QueryRunnerTestHelper.INDEX_DOUBLE_SUM, - new VarianceAggregatorFactory("variance", "index") + new VarianceAggregatorFactory("variance", "index", null, null) ) ) .descending(true) @@ -648,9 +649,18 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest { Assert.assertEquals(expectedResults.size(), results.size()); for (int i = 0; i < expectedResults.size(); i++) { - Assert.assertArrayEquals(expectedResults.get(i), results.get(i)); + Object[] expectedResult = expectedResults.get(i); + Object[] result = results.get(i); + Assert.assertEquals(expectedResult.length, result.length); + for (int j = 0; j < expectedResult.length; j++) { + if (expectedResult[j] instanceof Float) { + Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10); + } else if (expectedResult[j] instanceof Double) { + Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10); + } else { + Assert.assertEquals(expectedResult[j], result[j]); + } + } } } - - }