Vectorized variance aggregators (#10390)

* wip vectorize

* close but not quite

* faster

* unit tests

* fix complex types for variance
This commit is contained in:
Suneet Saldanha 2020-09-17 15:05:40 -07:00 committed by GitHub
parent 1b05d6e542
commit 0b4c897fbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1555 additions and 133 deletions

View File

@ -82,6 +82,11 @@
<artifactId>druid-histogram</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-stats</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-core</artifactId>
@ -172,7 +177,7 @@
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-protobuf-extensions</artifactId>
<version>0.20.0-SNAPSHOT</version>
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

View File

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 5)
@Measurement(iterations = 5)
public class VarianceBenchmark
{
@Param({"128", "256", "512", "1024"})
int vectorSize;
private float[] randomValues;
@Setup
public void setup()
{
randomValues = new float[vectorSize];
Random r = ThreadLocalRandom.current();
for (int i = 0; i < vectorSize; i++) {
randomValues[i] = r.nextFloat();
}
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void collectVarianceOneByOne(Blackhole blackhole)
{
VarianceAggregatorCollector collector = new VarianceAggregatorCollector();
for (float v : randomValues) {
collector.add(v);
}
blackhole.consume(collector);
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void collectVarianceInBatch(Blackhole blackhole)
{
double sum = 0, nvariance = 0;
for (float v : randomValues) {
sum += v;
}
double mean = sum / randomValues.length;
for (float v : randomValues) {
nvariance += (v - mean) * (v - mean);
}
VarianceAggregatorCollector collector = new VarianceAggregatorCollector(randomValues.length, sum, nvariance);
blackhole.consume(collector);
}
}

View File

@ -76,6 +76,7 @@ public class VarianceAggregatorCollector
if (other == null || other.count == 0) {
return;
}
if (this.count == 0) {
this.nvariance = other.nvariance;
this.count = other.count;

View File

@ -22,6 +22,7 @@ package org.apache.druid.query.aggregation.variance;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
@ -35,12 +36,15 @@ import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.NoopAggregator;
import org.apache.druid.query.aggregation.NoopBufferAggregator;
import org.apache.druid.query.aggregation.ObjectAggregateCombiner;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -83,7 +87,8 @@ public class VarianceAggregatorFactory extends AggregatorFactory
this.inputType = inputType;
}
public VarianceAggregatorFactory(String name, String fieldName)
@VisibleForTesting
VarianceAggregatorFactory(String name, String fieldName)
{
this(name, fieldName, null, null);
}
@ -131,7 +136,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return new VarianceAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.LongVarianceAggregator(selector);
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
@ -156,16 +161,42 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return new VarianceBufferAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.LongVarianceAggregator(selector);
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) {
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.ObjectVarianceAggregator(selector);
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
inputType
type
);
}
@Override
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
{
final String type = getTypeString(selectorFactory);
if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
return new VarianceFloatVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
return new VarianceDoubleVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceLongVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceObjectVectorAggregator(selectorFactory.makeObjectSelector(fieldName));
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
type
);
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public Object combine(Object lhs, Object rhs)
{
@ -340,11 +371,11 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return Objects.hash(fieldName, name, estimator, inputType, isVariancePop);
}
private String getTypeString(ColumnSelectorFactory metricFactory)
private String getTypeString(ColumnInspector columnInspector)
{
String type = inputType;
if (type == null) {
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
ColumnCapabilities capabilities = columnInspector.getColumnCapabilities(fieldName);
if (capabilities != null) {
type = StringUtils.toLowerCase(capabilities.getType().name());
} else {
@ -353,5 +384,4 @@ public class VarianceAggregatorFactory extends AggregatorFactory
}
return type;
}
}

View File

@ -35,25 +35,19 @@ import java.nio.ByteBuffer;
public abstract class VarianceBufferAggregator implements BufferAggregator
{
private static final int COUNT_OFFSET = 0;
private static final int SUM_OFFSET = Long.BYTES;
private static final int SUM_OFFSET = COUNT_OFFSET + Long.BYTES;
private static final int NVARIANCE_OFFSET = SUM_OFFSET + Double.BYTES;
@Override
public void init(final ByteBuffer buf, final int position)
{
buf.putLong(position + COUNT_OFFSET, 0)
.putDouble(position + SUM_OFFSET, 0)
.putDouble(position + NVARIANCE_OFFSET, 0);
doInit(buf, position);
}
@Override
public Object get(final ByteBuffer buf, final int position)
public VarianceAggregatorCollector get(final ByteBuffer buf, final int position)
{
VarianceAggregatorCollector holder = new VarianceAggregatorCollector();
holder.count = buf.getLong(position);
holder.sum = buf.getDouble(position + SUM_OFFSET);
holder.nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
return holder;
return getVarianceCollector(buf, position);
}
@Override
@ -79,6 +73,51 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
}
public static void doInit(ByteBuffer buf, int position)
{
buf.putLong(position + COUNT_OFFSET, 0)
.putDouble(position + SUM_OFFSET, 0)
.putDouble(position + NVARIANCE_OFFSET, 0);
}
public static long getCount(ByteBuffer buf, int position)
{
return buf.getLong(position + COUNT_OFFSET);
}
public static double getSum(ByteBuffer buf, int position)
{
return buf.getDouble(position + SUM_OFFSET);
}
public static double getVariance(ByteBuffer buf, int position)
{
return buf.getDouble(position + NVARIANCE_OFFSET);
}
public static VarianceAggregatorCollector getVarianceCollector(ByteBuffer buf, int position)
{
return new VarianceAggregatorCollector(
getCount(buf, position),
getSum(buf, position),
getVariance(buf, position)
);
}
public static void writeNVariance(ByteBuffer buf, int position, long count, double sum, double nvariance)
{
buf.putLong(position + COUNT_OFFSET, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
}
}
public static void writeCountAndSum(ByteBuffer buf, int position, long count, double sum)
{
buf.putLong(position + COUNT_OFFSET, count);
buf.putDouble(position + SUM_OFFSET, sum);
}
public static final class FloatVarianceAggregator extends VarianceBufferAggregator
{
private final boolean noNulls = NullHandling.replaceWithDefault();
@ -94,10 +133,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
if (noNulls || !selector.isNull()) {
float v = selector.getFloat();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
long count = getCount(buf, position) + 1;
double sum = getSum(buf, position) + v;
writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@ -128,10 +166,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
if (noNulls || !selector.isNull()) {
double v = selector.getDouble();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
long count = getCount(buf, position) + 1;
double sum = getSum(buf, position) + v;
writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@ -162,10 +199,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
if (noNulls || !selector.isNull()) {
long v = selector.getLong();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
long count = getCount(buf, position) + 1;
double sum = getSum(buf, position) + v;
writeCountAndSum(buf, position, count, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@ -195,7 +231,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{
VarianceAggregatorCollector holder2 = (VarianceAggregatorCollector) selector.getObject();
Preconditions.checkState(holder2 != null);
long count = buf.getLong(position + COUNT_OFFSET);
long count = getCount(buf, position);
if (count == 0) {
buf.putLong(position, holder2.count);
buf.putDouble(position + SUM_OFFSET, holder2.sum);
@ -203,7 +239,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
return;
}
double sum = buf.getDouble(position + SUM_OFFSET);
double sum = getSum(buf, position);
double nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
final double ratio = count / (double) holder2.count;
@ -213,9 +249,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
count += holder2.count;
sum += holder2.sum;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
writeNVariance(buf, position, count, sum, nvariance);
}
@Override

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for doubles.
*/
public class VarianceDoubleVectorAggregator implements VectorAggregator
{
private final VectorValueSelector selector;
private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
public VarianceDoubleVectorAggregator(VectorValueSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
double[] vector = selector.getDoubleVector();
long count = 0;
double sum = 0, nvariance = 0;
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
count++;
sum += vector[i];
}
}
double mean = sum / count;
if (count > 1) {
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
nvariance += (vector[i] - mean) * (vector[i] - mean);
}
}
}
VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
VarianceBufferAggregator.getCount(buf, position),
VarianceBufferAggregator.getSum(buf, position),
VarianceBufferAggregator.getVariance(buf, position)
);
previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
double[] vector = selector.getDoubleVector();
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
if (nulls == null || !nulls[row]) {
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.add(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for floats.
*/
public class VarianceFloatVectorAggregator implements VectorAggregator
{
private final VectorValueSelector selector;
private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
public VarianceFloatVectorAggregator(VectorValueSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
float[] vector = selector.getFloatVector();
long count = 0;
double sum = 0, nvariance = 0;
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
count++;
sum += vector[i];
}
}
double mean = sum / count;
if (count > 1) {
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
nvariance += (vector[i] - mean) * (vector[i] - mean);
}
}
}
VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
VarianceBufferAggregator.getCount(buf, position),
VarianceBufferAggregator.getSum(buf, position),
VarianceBufferAggregator.getVariance(buf, position)
);
previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
float[] vector = selector.getFloatVector();
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
if (nulls == null || !nulls[row]) {
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.add(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for longs.
*/
public class VarianceLongVectorAggregator implements VectorAggregator
{
private final VectorValueSelector selector;
private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
public VarianceLongVectorAggregator(VectorValueSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
long[] vector = selector.getLongVector();
long count = 0;
double sum = 0, nvariance = 0;
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
count++;
sum += vector[i];
}
}
double mean = sum / count;
if (count > 1) {
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
nvariance += (vector[i] - mean) * (vector[i] - mean);
}
}
}
VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
VarianceBufferAggregator.getCount(buf, position),
VarianceBufferAggregator.getSum(buf, position),
VarianceBufferAggregator.getVariance(buf, position)
);
previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
long[] vector = selector.getLongVector();
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
if (nulls == null || !nulls[row]) {
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.add(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorObjectSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for {@link VarianceAggregatorCollector}.
*/
public class VarianceObjectVectorAggregator implements VectorAggregator
{
private final VectorObjectSelector selector;
public VarianceObjectVectorAggregator(VectorObjectSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector();
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
for (int i = startRow; i < endRow; i++) {
previous.fold(vector[i]);
}
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.fold(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
@RunWith(MockitoJUnitRunner.class)
public class VarianceAggregatorFactoryUnitTest extends InitializedNullHandlingTest
{
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final String DOUBLE = "double";
private static final String LONG = "long";
private static final String VARIANCE = "variance";
private static final String UNKNOWN = "unknown";
@Mock
private ColumnCapabilities capabilities;
@Mock
private VectorColumnSelectorFactory selectorFactory;
@Mock(answer = Answers.RETURNS_MOCKS)
private ColumnSelectorFactory metricFactory;
private VarianceAggregatorFactory target;
@Before
public void setup()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME);
}
@Test
public void factorizeVectorShouldReturnFloatVectorAggregator()
{
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceFloatVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForDoubleShouldReturnFloatVectorAggregator()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, DOUBLE);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceDoubleVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForLongShouldReturnFloatVectorAggregator()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, LONG);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceLongVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForVarianceShouldReturnObjectVectorAggregator()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, VARIANCE);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForComplexShouldReturnObjectVectorAggregator()
{
mockType(ValueType.COMPLEX);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeBufferedForComplexShouldReturnObjectVectorAggregator()
{
mockType(ValueType.COMPLEX);
BufferAggregator agg = target.factorizeBuffered(metricFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceBufferAggregator.ObjectVarianceAggregator.class, agg.getClass());
}
@Test
public void factorizeForComplexShouldReturnObjectVectorAggregator()
{
mockType(ValueType.COMPLEX);
Aggregator agg = target.factorize(metricFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceAggregator.ObjectVarianceAggregator.class, agg.getClass());
}
@Test(expected = IAE.class)
public void factorizeVectorForUnknownColumnShouldThrowIAE()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
target.factorizeVector(selectorFactory);
}
@Test(expected = IAE.class)
public void factorizeBufferedForUnknownColumnShouldThrowIAE()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
target.factorizeBuffered(metricFactory);
}
@Test
public void equalsContract()
{
EqualsVerifier.forClass(VarianceAggregatorFactory.class)
.usingGetClass()
.verify();
}
private void mockType(ValueType type)
{
Mockito.doReturn(capabilities).when(selectorFactory).getColumnCapabilities(FIELD_NAME);
Mockito.doReturn(capabilities).when(metricFactory).getColumnCapabilities(FIELD_NAME);
Mockito.doReturn(type).when(capabilities).getType();
}
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceDoubleVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-10;
private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60, 123};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorValueSelector selector;
private ByteBuffer buf;
private VarianceDoubleVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getDoubleVector();
target = new VarianceDoubleVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(VALUES.length - START_ROW, collector.count);
Assert.assertEquals(217.67, collector.sum, EPSILON);
Assert.assertEquals(7565.211675, collector.nvariance, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(
VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
collector.count
);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67 : 134, collector.sum, EPSILON);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.211675 : 6272, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRowsAndNulls()
{
mockNullsVector();
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
Assert.assertEquals(isNull ? 0 : 1, collector.count);
Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceFloatVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-8;
private static final float[] VALUES = new float[]{7.8F, 11, 23.67F, 60, 123};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorValueSelector selector;
private ByteBuffer buf;
private VarianceFloatVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getFloatVector();
target = new VarianceFloatVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(VALUES.length - START_ROW, collector.count);
Assert.assertEquals(217.67000007, collector.sum, EPSILON);
Assert.assertEquals(7565.2116703, collector.nvariance, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(
VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
collector.count
);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67000007 : 134, collector.sum, EPSILON);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.2116703 : 6272, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRowsAndNulls()
{
mockNullsVector();
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
Assert.assertEquals(isNull ? 0 : 1, collector.count);
Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
@ -63,14 +64,12 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
private final QueryRunner<Row> runner;
private final GroupByQueryRunnerFactory factory;
private final String testName;
private final GroupByQuery.Builder queryBuilder;
@Parameterized.Parameters(name = "{0}")
public static Collection<Object[]> constructorFeeder()
{
// Use GroupByQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator
// can't vectorize yet.
return GroupByQueryRunnerTest.constructorFeeder().stream()
.filter(constructor -> !((boolean) constructor[4]) /* !vectorize */)
.map(
constructor ->
new Object[]{
@ -94,13 +93,14 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
this.config = config;
this.factory = factory;
this.runner = factory.mergeRunners(Execs.directExecutor(), ImmutableList.of(runner));
this.queryBuilder = GroupByQuery.builder()
.setContext(ImmutableMap.of("vectorize", config.isVectorize()));
}
@Test
public void testGroupByVarianceOnly()
{
GroupByQuery query = GroupByQuery
.builder()
GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -141,8 +141,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
@Test
public void testGroupBy()
{
GroupByQuery query = GroupByQuery
.builder()
GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -191,8 +190,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
new String[]{"alias", "rows", "index", "index_var", "index_stddev"}
);
GroupByQuery query = GroupByQuery
.builder()
GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setInterval("2011-04-02/2011-04-04")
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -244,8 +242,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
public void testGroupByZtestPostAgg()
{
// test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test
GroupByQuery query = GroupByQuery
.builder()
GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -286,8 +283,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
public void testGroupByTestPvalueZscorePostAgg()
{
// test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test
GroupByQuery query = GroupByQuery
.builder()
GroupByQuery query = queryBuilder
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -308,7 +304,14 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
.build();
VarianceTestHelper.RowBuilder builder =
new VarianceTestHelper.RowBuilder(new String[]{"alias", "rows", "idx", "index_stddev", "index_var", "pvalueZscore"});
new VarianceTestHelper.RowBuilder(new String[]{
"alias",
"rows",
"idx",
"index_stddev",
"index_var",
"pvalueZscore"
});
List<ResultRow> expectedResults = builder
.add("2011-04-01", "automotive", 1L, 135.0, 0.0, 0.0, 1.0)

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceLongVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-10;
private static final long[] VALUES = new long[]{7, 11, 23, 60, 123};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorValueSelector selector;
private ByteBuffer buf;
private VarianceLongVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getLongVector();
target = new VarianceLongVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(VALUES.length - START_ROW, collector.count);
Assert.assertEquals(217, collector.sum, EPSILON);
Assert.assertEquals(7606.75, collector.nvariance, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(
VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
collector.count
);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 217 : 134, collector.sum, EPSILON);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 7606.75 : 6272, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRowsAndNulls()
{
mockNullsVector();
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
Assert.assertEquals(isNull ? 0 : 1, collector.count);
Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceObjectVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-10;
private static final VarianceAggregatorCollector[] VALUES = new VarianceAggregatorCollector[]{
new VarianceAggregatorCollector(1, 7.8, 0),
new VarianceAggregatorCollector(1, 11, 0),
new VarianceAggregatorCollector(1, 23.67, 0),
null,
new VarianceAggregatorCollector(2, 183, 1984.5)
};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorObjectSelector selector;
private ByteBuffer buf;
private VarianceObjectVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getObjectVector();
target = new VarianceObjectVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(4, collector.count);
Assert.assertEquals(217.67, collector.sum, EPSILON);
Assert.assertEquals(7565.211675, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(VALUES[i], collector);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
VarianceAggregatorCollector expectedCollector = VALUES[rows[i]];
Assert.assertEquals(expectedCollector == null ? new VarianceAggregatorCollector() : expectedCollector, collector);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
}

View File

@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryPlus;
@ -46,31 +47,32 @@ public class VarianceTimeseriesQueryTest extends InitializedNullHandlingTest
@Parameterized.Parameters(name = "{0}:descending={1}")
public static Iterable<Object[]> constructorFeeder()
{
// Use TimeseriesQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator
// can't vectorize yet.
return StreamSupport.stream(TimeseriesQueryRunnerTest.constructorFeeder().spliterator(), false)
.filter(constructor -> !((boolean) constructor[2]) /* !vectorize */)
.map(constructor -> new Object[]{constructor[0], constructor[1], constructor[3]})
.map(constructor -> new Object[]{constructor[0], constructor[1], constructor[2], constructor[3]})
.collect(Collectors.toList());
}
private final QueryRunner runner;
private final boolean descending;
private final Druids.TimeseriesQueryBuilder queryBuilder;
public VarianceTimeseriesQueryTest(
QueryRunner runner,
boolean descending,
boolean vectorize,
List<AggregatorFactory> aggregatorFactories
)
{
this.runner = runner;
this.descending = descending;
this.queryBuilder = Druids.newTimeseriesQueryBuilder()
.context(ImmutableMap.of("vectorize", vectorize ? "force" : "false"));
}
@Test
public void testTimeseriesWithNullFilterOnNonExistentDimension()
{
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
TimeseriesQuery query = queryBuilder
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.DAY_GRAN)
.filters("bobby", null)

View File

@ -285,16 +285,16 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
)
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
@ -335,22 +335,22 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
holder1.getVariance(false),
holder2.getVariance(false).floatValue(),
holder3.getVariance(false).longValue(),
}
}
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
)
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
@ -391,28 +391,29 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Math.sqrt(holder1.getVariance(true)),
(float) Math.sqrt(holder2.getVariance(true)),
(long) Math.sqrt(holder3.getVariance(true)),
}
}
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
)
.postAggregators(
ImmutableList.of(
new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
new StandardDeviationPostAggregator("a2", "a2:agg", "population"))
)
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
)
.postAggregators(
ImmutableList.of(
new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
new StandardDeviationPostAggregator("a2", "a2:agg", "population")
)
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
@ -453,7 +454,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Math.sqrt(holder1.getVariance(false)),
(float) Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
}
}
);
assertResultsEquals(expectedResults, results);
@ -464,9 +465,9 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
)
.postAggregators(
@ -514,7 +515,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Math.sqrt(holder1.getVariance(false)),
(float) Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
}
}
);
assertResultsEquals(expectedResults, results);
@ -530,9 +531,9 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
)
)
.postAggregators(
@ -560,41 +561,41 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
authenticationResult
).toList();
List<Object[]> expectedResults = NullHandling.sqlCompatible()
? ImmutableList.of(
new Object[] {"a", 0f},
new Object[] {null, 0f},
new Object[] {"", 0f},
new Object[] {"abc", null}
? ImmutableList.of(
new Object[]{"a", 0f},
new Object[]{null, 0f},
new Object[]{"", 0f},
new Object[]{"abc", null}
) : ImmutableList.of(
new Object[] {"a", 0.5f},
new Object[] {"", 0.0033333334f},
new Object[] {"abc", 0f}
new Object[]{"a", 0.5f},
new Object[]{"", 0.0033333334f},
new Object[]{"abc", 0f}
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE3)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setDimensions(new DefaultDimensionSpec("dim2", "_d0"))
.setAggregatorSpecs(
new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float")
)
.setLimitSpec(
DefaultLimitSpec
.builder()
.orderBy(
new OrderByColumnSpec(
"a0:agg",
OrderByColumnSpec.Direction.DESCENDING,
StringComparators.NUMERIC
)
)
.build()
)
.setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build(),
.setDataSource(CalciteTests.DATASOURCE3)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setDimensions(new DefaultDimensionSpec("dim2", "_d0"))
.setAggregatorSpecs(
new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float")
)
.setLimitSpec(
DefaultLimitSpec
.builder()
.orderBy(
new OrderByColumnSpec(
"a0:agg",
OrderByColumnSpec.Direction.DESCENDING,
StringComparators.NUMERIC
)
)
.build()
)
.setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@ -622,7 +623,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Arrays.asList(
QueryRunnerTestHelper.ROWS_COUNT,
QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
new VarianceAggregatorFactory("variance", "index")
new VarianceAggregatorFactory("variance", "index", null, null)
)
)
.descending(true)
@ -648,9 +649,18 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
{
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
Object[] expectedResult = expectedResults.get(i);
Object[] result = results.get(i);
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}
}
}
}
}