Vectorized variance aggregators (#10390)

* wip vectorize

* close but not quite

* faster

* unit tests

* fix complex types for variance
This commit is contained in:
Suneet Saldanha 2020-09-17 15:05:40 -07:00 committed by GitHub
parent 1b05d6e542
commit 0b4c897fbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1555 additions and 133 deletions

View File

@ -82,6 +82,11 @@
<artifactId>druid-histogram</artifactId> <artifactId>druid-histogram</artifactId>
<version>${project.parent.version}</version> <version>${project.parent.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-stats</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.apache.druid</groupId> <groupId>org.apache.druid</groupId>
<artifactId>druid-core</artifactId> <artifactId>druid-core</artifactId>
@ -172,7 +177,7 @@
<dependency> <dependency>
<groupId>org.apache.druid.extensions</groupId> <groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-protobuf-extensions</artifactId> <artifactId>druid-protobuf-extensions</artifactId>
<version>0.20.0-SNAPSHOT</version> <version>${project.parent.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>

View File

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 5)
@Measurement(iterations = 5)
public class VarianceBenchmark
{
@Param({"128", "256", "512", "1024"})
int vectorSize;
private float[] randomValues;
@Setup
public void setup()
{
randomValues = new float[vectorSize];
Random r = ThreadLocalRandom.current();
for (int i = 0; i < vectorSize; i++) {
randomValues[i] = r.nextFloat();
}
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void collectVarianceOneByOne(Blackhole blackhole)
{
VarianceAggregatorCollector collector = new VarianceAggregatorCollector();
for (float v : randomValues) {
collector.add(v);
}
blackhole.consume(collector);
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public void collectVarianceInBatch(Blackhole blackhole)
{
double sum = 0, nvariance = 0;
for (float v : randomValues) {
sum += v;
}
double mean = sum / randomValues.length;
for (float v : randomValues) {
nvariance += (v - mean) * (v - mean);
}
VarianceAggregatorCollector collector = new VarianceAggregatorCollector(randomValues.length, sum, nvariance);
blackhole.consume(collector);
}
}

View File

@ -76,6 +76,7 @@ public class VarianceAggregatorCollector
if (other == null || other.count == 0) { if (other == null || other.count == 0) {
return; return;
} }
if (this.count == 0) { if (this.count == 0) {
this.nvariance = other.nvariance; this.nvariance = other.nvariance;
this.count = other.count; this.count = other.count;

View File

@ -22,6 +22,7 @@ package org.apache.druid.query.aggregation.variance;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
@ -35,12 +36,15 @@ import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.NoopAggregator; import org.apache.druid.query.aggregation.NoopAggregator;
import org.apache.druid.query.aggregation.NoopBufferAggregator; import org.apache.druid.query.aggregation.NoopBufferAggregator;
import org.apache.druid.query.aggregation.ObjectAggregateCombiner; import org.apache.druid.query.aggregation.ObjectAggregateCombiner;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector; import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -83,7 +87,8 @@ public class VarianceAggregatorFactory extends AggregatorFactory
this.inputType = inputType; this.inputType = inputType;
} }
public VarianceAggregatorFactory(String name, String fieldName) @VisibleForTesting
VarianceAggregatorFactory(String name, String fieldName)
{ {
this(name, fieldName, null, null); this(name, fieldName, null, null);
} }
@ -131,7 +136,7 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return new VarianceAggregator.DoubleVarianceAggregator(selector); return new VarianceAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) { } else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.LongVarianceAggregator(selector); return new VarianceAggregator.LongVarianceAggregator(selector);
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) { } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceAggregator.ObjectVarianceAggregator(selector); return new VarianceAggregator.ObjectVarianceAggregator(selector);
} }
throw new IAE( throw new IAE(
@ -156,16 +161,42 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return new VarianceBufferAggregator.DoubleVarianceAggregator(selector); return new VarianceBufferAggregator.DoubleVarianceAggregator(selector);
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) { } else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.LongVarianceAggregator(selector); return new VarianceBufferAggregator.LongVarianceAggregator(selector);
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type)) { } else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceBufferAggregator.ObjectVarianceAggregator(selector); return new VarianceBufferAggregator.ObjectVarianceAggregator(selector);
} }
throw new IAE( throw new IAE(
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s", "Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName, fieldName,
inputType type
); );
} }
@Override
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
{
final String type = getTypeString(selectorFactory);
if (ValueType.FLOAT.name().equalsIgnoreCase(type)) {
return new VarianceFloatVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else if (ValueType.DOUBLE.name().equalsIgnoreCase(type)) {
return new VarianceDoubleVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else if (ValueType.LONG.name().equalsIgnoreCase(type)) {
return new VarianceLongVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else if (VARIANCE_TYPE_NAME.equalsIgnoreCase(type) || ValueType.COMPLEX.name().equalsIgnoreCase(type)) {
return new VarianceObjectVectorAggregator(selectorFactory.makeObjectSelector(fieldName));
}
throw new IAE(
"Incompatible type for metric[%s], expected a float, double, long, or variance, but got a %s",
fieldName,
type
);
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override @Override
public Object combine(Object lhs, Object rhs) public Object combine(Object lhs, Object rhs)
{ {
@ -340,11 +371,11 @@ public class VarianceAggregatorFactory extends AggregatorFactory
return Objects.hash(fieldName, name, estimator, inputType, isVariancePop); return Objects.hash(fieldName, name, estimator, inputType, isVariancePop);
} }
private String getTypeString(ColumnSelectorFactory metricFactory) private String getTypeString(ColumnInspector columnInspector)
{ {
String type = inputType; String type = inputType;
if (type == null) { if (type == null) {
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName); ColumnCapabilities capabilities = columnInspector.getColumnCapabilities(fieldName);
if (capabilities != null) { if (capabilities != null) {
type = StringUtils.toLowerCase(capabilities.getType().name()); type = StringUtils.toLowerCase(capabilities.getType().name());
} else { } else {
@ -353,5 +384,4 @@ public class VarianceAggregatorFactory extends AggregatorFactory
} }
return type; return type;
} }
} }

View File

@ -35,25 +35,19 @@ import java.nio.ByteBuffer;
public abstract class VarianceBufferAggregator implements BufferAggregator public abstract class VarianceBufferAggregator implements BufferAggregator
{ {
private static final int COUNT_OFFSET = 0; private static final int COUNT_OFFSET = 0;
private static final int SUM_OFFSET = Long.BYTES; private static final int SUM_OFFSET = COUNT_OFFSET + Long.BYTES;
private static final int NVARIANCE_OFFSET = SUM_OFFSET + Double.BYTES; private static final int NVARIANCE_OFFSET = SUM_OFFSET + Double.BYTES;
@Override @Override
public void init(final ByteBuffer buf, final int position) public void init(final ByteBuffer buf, final int position)
{ {
buf.putLong(position + COUNT_OFFSET, 0) doInit(buf, position);
.putDouble(position + SUM_OFFSET, 0)
.putDouble(position + NVARIANCE_OFFSET, 0);
} }
@Override @Override
public Object get(final ByteBuffer buf, final int position) public VarianceAggregatorCollector get(final ByteBuffer buf, final int position)
{ {
VarianceAggregatorCollector holder = new VarianceAggregatorCollector(); return getVarianceCollector(buf, position);
holder.count = buf.getLong(position);
holder.sum = buf.getDouble(position + SUM_OFFSET);
holder.nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
return holder;
} }
@Override @Override
@ -79,6 +73,51 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{ {
} }
public static void doInit(ByteBuffer buf, int position)
{
buf.putLong(position + COUNT_OFFSET, 0)
.putDouble(position + SUM_OFFSET, 0)
.putDouble(position + NVARIANCE_OFFSET, 0);
}
public static long getCount(ByteBuffer buf, int position)
{
return buf.getLong(position + COUNT_OFFSET);
}
public static double getSum(ByteBuffer buf, int position)
{
return buf.getDouble(position + SUM_OFFSET);
}
public static double getVariance(ByteBuffer buf, int position)
{
return buf.getDouble(position + NVARIANCE_OFFSET);
}
public static VarianceAggregatorCollector getVarianceCollector(ByteBuffer buf, int position)
{
return new VarianceAggregatorCollector(
getCount(buf, position),
getSum(buf, position),
getVariance(buf, position)
);
}
public static void writeNVariance(ByteBuffer buf, int position, long count, double sum, double nvariance)
{
buf.putLong(position + COUNT_OFFSET, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
}
}
public static void writeCountAndSum(ByteBuffer buf, int position, long count, double sum)
{
buf.putLong(position + COUNT_OFFSET, count);
buf.putDouble(position + SUM_OFFSET, sum);
}
public static final class FloatVarianceAggregator extends VarianceBufferAggregator public static final class FloatVarianceAggregator extends VarianceBufferAggregator
{ {
private final boolean noNulls = NullHandling.replaceWithDefault(); private final boolean noNulls = NullHandling.replaceWithDefault();
@ -94,10 +133,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{ {
if (noNulls || !selector.isNull()) { if (noNulls || !selector.isNull()) {
float v = selector.getFloat(); float v = selector.getFloat();
long count = buf.getLong(position + COUNT_OFFSET) + 1; long count = getCount(buf, position) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v; double sum = getSum(buf, position) + v;
buf.putLong(position, count); writeCountAndSum(buf, position, count, sum);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) { if (count > 1) {
double t = count * v - sum; double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@ -128,10 +166,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{ {
if (noNulls || !selector.isNull()) { if (noNulls || !selector.isNull()) {
double v = selector.getDouble(); double v = selector.getDouble();
long count = buf.getLong(position + COUNT_OFFSET) + 1; long count = getCount(buf, position) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v; double sum = getSum(buf, position) + v;
buf.putLong(position, count); writeCountAndSum(buf, position, count, sum);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) { if (count > 1) {
double t = count * v - sum; double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@ -162,10 +199,9 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{ {
if (noNulls || !selector.isNull()) { if (noNulls || !selector.isNull()) {
long v = selector.getLong(); long v = selector.getLong();
long count = buf.getLong(position + COUNT_OFFSET) + 1; long count = getCount(buf, position) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v; double sum = getSum(buf, position) + v;
buf.putLong(position, count); writeCountAndSum(buf, position, count, sum);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) { if (count > 1) {
double t = count * v - sum; double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1)); double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
@ -195,7 +231,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
{ {
VarianceAggregatorCollector holder2 = (VarianceAggregatorCollector) selector.getObject(); VarianceAggregatorCollector holder2 = (VarianceAggregatorCollector) selector.getObject();
Preconditions.checkState(holder2 != null); Preconditions.checkState(holder2 != null);
long count = buf.getLong(position + COUNT_OFFSET); long count = getCount(buf, position);
if (count == 0) { if (count == 0) {
buf.putLong(position, holder2.count); buf.putLong(position, holder2.count);
buf.putDouble(position + SUM_OFFSET, holder2.sum); buf.putDouble(position + SUM_OFFSET, holder2.sum);
@ -203,7 +239,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
return; return;
} }
double sum = buf.getDouble(position + SUM_OFFSET); double sum = getSum(buf, position);
double nvariance = buf.getDouble(position + NVARIANCE_OFFSET); double nvariance = buf.getDouble(position + NVARIANCE_OFFSET);
final double ratio = count / (double) holder2.count; final double ratio = count / (double) holder2.count;
@ -213,9 +249,7 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
count += holder2.count; count += holder2.count;
sum += holder2.sum; sum += holder2.sum;
buf.putLong(position, count); writeNVariance(buf, position, count, sum, nvariance);
buf.putDouble(position + SUM_OFFSET, sum);
buf.putDouble(position + NVARIANCE_OFFSET, nvariance);
} }
@Override @Override

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for doubles.
*/
public class VarianceDoubleVectorAggregator implements VectorAggregator
{
private final VectorValueSelector selector;
private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
public VarianceDoubleVectorAggregator(VectorValueSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
double[] vector = selector.getDoubleVector();
long count = 0;
double sum = 0, nvariance = 0;
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
count++;
sum += vector[i];
}
}
double mean = sum / count;
if (count > 1) {
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
nvariance += (vector[i] - mean) * (vector[i] - mean);
}
}
}
VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
VarianceBufferAggregator.getCount(buf, position),
VarianceBufferAggregator.getSum(buf, position),
VarianceBufferAggregator.getVariance(buf, position)
);
previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
double[] vector = selector.getDoubleVector();
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
if (nulls == null || !nulls[row]) {
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.add(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for floats.
*/
public class VarianceFloatVectorAggregator implements VectorAggregator
{
private final VectorValueSelector selector;
private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
public VarianceFloatVectorAggregator(VectorValueSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
float[] vector = selector.getFloatVector();
long count = 0;
double sum = 0, nvariance = 0;
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
count++;
sum += vector[i];
}
}
double mean = sum / count;
if (count > 1) {
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
nvariance += (vector[i] - mean) * (vector[i] - mean);
}
}
}
VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
VarianceBufferAggregator.getCount(buf, position),
VarianceBufferAggregator.getSum(buf, position),
VarianceBufferAggregator.getVariance(buf, position)
);
previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
float[] vector = selector.getFloatVector();
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
if (nulls == null || !nulls[row]) {
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.add(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for longs.
*/
public class VarianceLongVectorAggregator implements VectorAggregator
{
private final VectorValueSelector selector;
private final boolean replaceWithDefault = NullHandling.replaceWithDefault();
public VarianceLongVectorAggregator(VectorValueSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
long[] vector = selector.getLongVector();
long count = 0;
double sum = 0, nvariance = 0;
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
count++;
sum += vector[i];
}
}
double mean = sum / count;
if (count > 1) {
for (int i = startRow; i < endRow; i++) {
if (nulls == null || !nulls[i]) {
nvariance += (vector[i] - mean) * (vector[i] - mean);
}
}
}
VarianceAggregatorCollector previous = new VarianceAggregatorCollector(
VarianceBufferAggregator.getCount(buf, position),
VarianceBufferAggregator.getSum(buf, position),
VarianceBufferAggregator.getVariance(buf, position)
);
previous.fold(new VarianceAggregatorCollector(count, sum, nvariance));
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
long[] vector = selector.getLongVector();
boolean[] nulls = replaceWithDefault ? null : selector.getNullVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
if (nulls == null || !nulls[row]) {
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.add(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorObjectSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized implementation of {@link VarianceBufferAggregator} for {@link VarianceAggregatorCollector}.
*/
public class VarianceObjectVectorAggregator implements VectorAggregator
{
private final VectorObjectSelector selector;
public VarianceObjectVectorAggregator(VectorObjectSelector selector)
{
this.selector = selector;
}
@Override
public void init(ByteBuffer buf, int position)
{
VarianceBufferAggregator.doInit(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector();
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
for (int i = startRow; i < endRow; i++) {
previous.fold(vector[i]);
}
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
VarianceAggregatorCollector[] vector = (VarianceAggregatorCollector[]) selector.getObjectVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows != null ? rows[i] : i;
VarianceAggregatorCollector previous = VarianceBufferAggregator.getVarianceCollector(buf, position);
previous.fold(vector[row]);
VarianceBufferAggregator.writeNVariance(buf, position, previous.count, previous.sum, previous.nvariance);
}
}
@Nullable
@Override
public VarianceAggregatorCollector get(ByteBuffer buf, int position)
{
return VarianceBufferAggregator.getVarianceCollector(buf, position);
}
@Override
public void close()
{
// Nothing to close.
}
}

View File

@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
@RunWith(MockitoJUnitRunner.class)
public class VarianceAggregatorFactoryUnitTest extends InitializedNullHandlingTest
{
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final String DOUBLE = "double";
private static final String LONG = "long";
private static final String VARIANCE = "variance";
private static final String UNKNOWN = "unknown";
@Mock
private ColumnCapabilities capabilities;
@Mock
private VectorColumnSelectorFactory selectorFactory;
@Mock(answer = Answers.RETURNS_MOCKS)
private ColumnSelectorFactory metricFactory;
private VarianceAggregatorFactory target;
@Before
public void setup()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME);
}
@Test
public void factorizeVectorShouldReturnFloatVectorAggregator()
{
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceFloatVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForDoubleShouldReturnFloatVectorAggregator()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, DOUBLE);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceDoubleVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForLongShouldReturnFloatVectorAggregator()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, LONG);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceLongVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForVarianceShouldReturnObjectVectorAggregator()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, VARIANCE);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeVectorForComplexShouldReturnObjectVectorAggregator()
{
mockType(ValueType.COMPLEX);
VectorAggregator agg = target.factorizeVector(selectorFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceObjectVectorAggregator.class, agg.getClass());
}
@Test
public void factorizeBufferedForComplexShouldReturnObjectVectorAggregator()
{
mockType(ValueType.COMPLEX);
BufferAggregator agg = target.factorizeBuffered(metricFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceBufferAggregator.ObjectVarianceAggregator.class, agg.getClass());
}
@Test
public void factorizeForComplexShouldReturnObjectVectorAggregator()
{
mockType(ValueType.COMPLEX);
Aggregator agg = target.factorize(metricFactory);
Assert.assertNotNull(agg);
Assert.assertEquals(VarianceAggregator.ObjectVarianceAggregator.class, agg.getClass());
}
@Test(expected = IAE.class)
public void factorizeVectorForUnknownColumnShouldThrowIAE()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
target.factorizeVector(selectorFactory);
}
@Test(expected = IAE.class)
public void factorizeBufferedForUnknownColumnShouldThrowIAE()
{
target = new VarianceAggregatorFactory(NAME, FIELD_NAME, null, UNKNOWN);
target.factorizeBuffered(metricFactory);
}
@Test
public void equalsContract()
{
EqualsVerifier.forClass(VarianceAggregatorFactory.class)
.usingGetClass()
.verify();
}
private void mockType(ValueType type)
{
Mockito.doReturn(capabilities).when(selectorFactory).getColumnCapabilities(FIELD_NAME);
Mockito.doReturn(capabilities).when(metricFactory).getColumnCapabilities(FIELD_NAME);
Mockito.doReturn(type).when(capabilities).getType();
}
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceDoubleVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-10;
private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60, 123};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorValueSelector selector;
private ByteBuffer buf;
private VarianceDoubleVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getDoubleVector();
target = new VarianceDoubleVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(VALUES.length - START_ROW, collector.count);
Assert.assertEquals(217.67, collector.sum, EPSILON);
Assert.assertEquals(7565.211675, collector.nvariance, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(
VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
collector.count
);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67 : 134, collector.sum, EPSILON);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.211675 : 6272, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRowsAndNulls()
{
mockNullsVector();
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
Assert.assertEquals(isNull ? 0 : 1, collector.count);
Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceFloatVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-8;
private static final float[] VALUES = new float[]{7.8F, 11, 23.67F, 60, 123};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorValueSelector selector;
private ByteBuffer buf;
private VarianceFloatVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getFloatVector();
target = new VarianceFloatVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(VALUES.length - START_ROW, collector.count);
Assert.assertEquals(217.67000007, collector.sum, EPSILON);
Assert.assertEquals(7565.2116703, collector.nvariance, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(
VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
collector.count
);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 217.67000007 : 134, collector.sum, EPSILON);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 7565.2116703 : 6272, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRowsAndNulls()
{
mockNullsVector();
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
Assert.assertEquals(isNull ? 0 : 1, collector.count);
Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.variance; package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.data.input.Row; import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.java.util.common.granularity.PeriodGranularity;
@ -63,14 +64,12 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
private final QueryRunner<Row> runner; private final QueryRunner<Row> runner;
private final GroupByQueryRunnerFactory factory; private final GroupByQueryRunnerFactory factory;
private final String testName; private final String testName;
private final GroupByQuery.Builder queryBuilder;
@Parameterized.Parameters(name = "{0}") @Parameterized.Parameters(name = "{0}")
public static Collection<Object[]> constructorFeeder() public static Collection<Object[]> constructorFeeder()
{ {
// Use GroupByQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator
// can't vectorize yet.
return GroupByQueryRunnerTest.constructorFeeder().stream() return GroupByQueryRunnerTest.constructorFeeder().stream()
.filter(constructor -> !((boolean) constructor[4]) /* !vectorize */)
.map( .map(
constructor -> constructor ->
new Object[]{ new Object[]{
@ -94,13 +93,14 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
this.config = config; this.config = config;
this.factory = factory; this.factory = factory;
this.runner = factory.mergeRunners(Execs.directExecutor(), ImmutableList.of(runner)); this.runner = factory.mergeRunners(Execs.directExecutor(), ImmutableList.of(runner));
this.queryBuilder = GroupByQuery.builder()
.setContext(ImmutableMap.of("vectorize", config.isVectorize()));
} }
@Test @Test
public void testGroupByVarianceOnly() public void testGroupByVarianceOnly()
{ {
GroupByQuery query = GroupByQuery GroupByQuery query = queryBuilder
.builder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias")) .setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -141,8 +141,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
@Test @Test
public void testGroupBy() public void testGroupBy()
{ {
GroupByQuery query = GroupByQuery GroupByQuery query = queryBuilder
.builder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias")) .setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -191,8 +190,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
new String[]{"alias", "rows", "index", "index_var", "index_stddev"} new String[]{"alias", "rows", "index", "index_var", "index_stddev"}
); );
GroupByQuery query = GroupByQuery GroupByQuery query = queryBuilder
.builder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setInterval("2011-04-02/2011-04-04") .setInterval("2011-04-02/2011-04-04")
.setDimensions(new DefaultDimensionSpec("quality", "alias")) .setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -244,8 +242,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
public void testGroupByZtestPostAgg() public void testGroupByZtestPostAgg()
{ {
// test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test // test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test
GroupByQuery query = GroupByQuery GroupByQuery query = queryBuilder
.builder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias")) .setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -286,8 +283,7 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
public void testGroupByTestPvalueZscorePostAgg() public void testGroupByTestPvalueZscorePostAgg()
{ {
// test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test // test postaggs from 'teststats' package in here since we've already gone to the trouble of setting up the test
GroupByQuery query = GroupByQuery GroupByQuery query = queryBuilder
.builder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE) .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
.setDimensions(new DefaultDimensionSpec("quality", "alias")) .setDimensions(new DefaultDimensionSpec("quality", "alias"))
@ -308,7 +304,14 @@ public class VarianceGroupByQueryTest extends InitializedNullHandlingTest
.build(); .build();
VarianceTestHelper.RowBuilder builder = VarianceTestHelper.RowBuilder builder =
new VarianceTestHelper.RowBuilder(new String[]{"alias", "rows", "idx", "index_stddev", "index_var", "pvalueZscore"}); new VarianceTestHelper.RowBuilder(new String[]{
"alias",
"rows",
"idx",
"index_stddev",
"index_var",
"pvalueZscore"
});
List<ResultRow> expectedResults = builder List<ResultRow> expectedResults = builder
.add("2011-04-01", "automotive", 1L, 135.0, 0.0, 0.0, 1.0) .add("2011-04-01", "automotive", 1L, 135.0, 0.0, 0.0, 1.0)

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceLongVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-10;
private static final long[] VALUES = new long[]{7, 11, 23, 60, 123};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorValueSelector selector;
private ByteBuffer buf;
private VarianceLongVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getLongVector();
target = new VarianceLongVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(VALUES.length - START_ROW, collector.count);
Assert.assertEquals(217, collector.sum, EPSILON);
Assert.assertEquals(7606.75, collector.nvariance, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(
VALUES.length - START_ROW - (NullHandling.replaceWithDefault() ? 0 : 2),
collector.count
);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 217 : 134, collector.sum, EPSILON);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 7606.75 : 6272, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[i], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(1, collector.count);
Assert.assertEquals(VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void aggregateBatchWithRowsAndNulls()
{
mockNullsVector();
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
boolean isNull = !NullHandling.replaceWithDefault() && NULLS[rows[i]];
Assert.assertEquals(isNull ? 0 : 1, collector.count);
Assert.assertEquals(isNull ? 0 : VALUES[rows[i]], collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class VarianceObjectVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final int START_ROW = 1;
private static final int POSITION = 2;
private static final int UNINIT_POSITION = 512;
private static final double EPSILON = 1e-10;
private static final VarianceAggregatorCollector[] VALUES = new VarianceAggregatorCollector[]{
new VarianceAggregatorCollector(1, 7.8, 0),
new VarianceAggregatorCollector(1, 11, 0),
new VarianceAggregatorCollector(1, 23.67, 0),
null,
new VarianceAggregatorCollector(2, 183, 1984.5)
};
private static final boolean[] NULLS = new boolean[]{false, false, true, true, false};
@Mock
private VectorObjectSelector selector;
private ByteBuffer buf;
private VarianceObjectVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getObjectVector();
target = new VarianceObjectVectorAggregator(selector);
clearBufferForPositions(0, POSITION);
}
@Test
public void initValueShouldInitZero()
{
target.init(buf, UNINIT_POSITION);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, UNINIT_POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, POSITION, START_ROW, VALUES.length);
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(buf, POSITION);
Assert.assertEquals(4, collector.count);
Assert.assertEquals(217.67, collector.sum, EPSILON);
Assert.assertEquals(7565.211675, collector.nvariance, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
Assert.assertEquals(VALUES[i], collector);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
VarianceAggregatorCollector collector = VarianceBufferAggregator.getVarianceCollector(
buf,
positions[i] + positionOffset
);
VarianceAggregatorCollector expectedCollector = VALUES[rows[i]];
Assert.assertEquals(expectedCollector == null ? new VarianceAggregatorCollector() : expectedCollector, collector);
}
}
@Test
public void getShouldReturnAllZeros()
{
VarianceAggregatorCollector collector = target.get(buf, POSITION);
Assert.assertEquals(0, collector.count);
Assert.assertEquals(0, collector.sum, EPSILON);
Assert.assertEquals(0, collector.nvariance, EPSILON);
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
VarianceBufferAggregator.doInit(buf, offset + position);
}
}
}

View File

@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance; package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.query.Druids; import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryPlus;
@ -46,31 +47,32 @@ public class VarianceTimeseriesQueryTest extends InitializedNullHandlingTest
@Parameterized.Parameters(name = "{0}:descending={1}") @Parameterized.Parameters(name = "{0}:descending={1}")
public static Iterable<Object[]> constructorFeeder() public static Iterable<Object[]> constructorFeeder()
{ {
// Use TimeseriesQueryRunnerTest's constructorFeeder, but remove vectorized tests, since this aggregator
// can't vectorize yet.
return StreamSupport.stream(TimeseriesQueryRunnerTest.constructorFeeder().spliterator(), false) return StreamSupport.stream(TimeseriesQueryRunnerTest.constructorFeeder().spliterator(), false)
.filter(constructor -> !((boolean) constructor[2]) /* !vectorize */) .map(constructor -> new Object[]{constructor[0], constructor[1], constructor[2], constructor[3]})
.map(constructor -> new Object[]{constructor[0], constructor[1], constructor[3]})
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private final QueryRunner runner; private final QueryRunner runner;
private final boolean descending; private final boolean descending;
private final Druids.TimeseriesQueryBuilder queryBuilder;
public VarianceTimeseriesQueryTest( public VarianceTimeseriesQueryTest(
QueryRunner runner, QueryRunner runner,
boolean descending, boolean descending,
boolean vectorize,
List<AggregatorFactory> aggregatorFactories List<AggregatorFactory> aggregatorFactories
) )
{ {
this.runner = runner; this.runner = runner;
this.descending = descending; this.descending = descending;
this.queryBuilder = Druids.newTimeseriesQueryBuilder()
.context(ImmutableMap.of("vectorize", vectorize ? "force" : "false"));
} }
@Test @Test
public void testTimeseriesWithNullFilterOnNonExistentDimension() public void testTimeseriesWithNullFilterOnNonExistentDimension()
{ {
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() TimeseriesQuery query = queryBuilder
.dataSource(QueryRunnerTestHelper.DATA_SOURCE) .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.DAY_GRAN) .granularity(QueryRunnerTestHelper.DAY_GRAN)
.filters("bobby", null) .filters("bobby", null)

View File

@ -411,7 +411,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
ImmutableList.of( ImmutableList.of(
new StandardDeviationPostAggregator("a0", "a0:agg", "population"), new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
new StandardDeviationPostAggregator("a1", "a1:agg", "population"), new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
new StandardDeviationPostAggregator("a2", "a2:agg", "population")) new StandardDeviationPostAggregator("a2", "a2:agg", "population")
)
) )
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT) .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(), .build(),
@ -561,14 +562,14 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
).toList(); ).toList();
List<Object[]> expectedResults = NullHandling.sqlCompatible() List<Object[]> expectedResults = NullHandling.sqlCompatible()
? ImmutableList.of( ? ImmutableList.of(
new Object[] {"a", 0f}, new Object[]{"a", 0f},
new Object[] {null, 0f}, new Object[]{null, 0f},
new Object[] {"", 0f}, new Object[]{"", 0f},
new Object[] {"abc", null} new Object[]{"abc", null}
) : ImmutableList.of( ) : ImmutableList.of(
new Object[] {"a", 0.5f}, new Object[]{"a", 0.5f},
new Object[] {"", 0.0033333334f}, new Object[]{"", 0.0033333334f},
new Object[] {"abc", 0f} new Object[]{"abc", 0f}
); );
assertResultsEquals(expectedResults, results); assertResultsEquals(expectedResults, results);
@ -622,7 +623,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Arrays.asList( Arrays.asList(
QueryRunnerTestHelper.ROWS_COUNT, QueryRunnerTestHelper.ROWS_COUNT,
QueryRunnerTestHelper.INDEX_DOUBLE_SUM, QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
new VarianceAggregatorFactory("variance", "index") new VarianceAggregatorFactory("variance", "index", null, null)
) )
) )
.descending(true) .descending(true)
@ -648,9 +649,18 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
{ {
Assert.assertEquals(expectedResults.size(), results.size()); Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) { for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i)); Object[] expectedResult = expectedResults.get(i);
Object[] result = results.get(i);
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}
}
} }
} }
} }