Support var_pop, var_samp, stddev_pop, stddev_samp, etc. in SQL (#7801)

* support var_pop, stddev_pop, etc. in SQL

* fix SQL compatibility

* rebase on master

* update docs
Xue Yu authored 2019-06-11 00:40:09 +08:00; committed by Fangjin Yang
parent c612ddc0f4 · commit ce591d1457
8 changed files with 812 additions and 21 deletions


@@ -129,6 +129,13 @@ Only the COUNT aggregation can accept DISTINCT.
|`APPROX_QUANTILE_DS(expr, probability, [k])`|Computes approximate quantiles on numeric or [Quantiles sketch](../development/extensions-core/datasketches-quantiles.html) exprs. The "probability" should be between 0 and 1 (exclusive). The `k` parameter is described in the Quantiles sketch documentation. The [DataSketches extension](../development/extensions-core/datasketches-extension.html) must be loaded to use this function.|
|`APPROX_QUANTILE_FIXED_BUCKETS(expr, probability, numBuckets, lowerLimit, upperLimit, [outlierHandlingMode])`|Computes approximate quantiles on numeric or [fixed buckets histogram](../development/extensions-core/approximate-histograms.html#fixed-buckets-histogram) exprs. The "probability" should be between 0 and 1 (exclusive). The `numBuckets`, `lowerLimit`, `upperLimit`, and `outlierHandlingMode` parameters are described in the fixed buckets histogram documentation. The [approximate histogram extension](../development/extensions-core/approximate-histograms.html) must be loaded to use this function.|
|`BLOOM_FILTER(expr, numEntries)`|Computes a bloom filter from values produced by `expr`, with `numEntries` as the maximum number of distinct values before the false positive rate increases. See [bloom filter extension](../development/extensions-core/bloom-filter.html) documentation for additional details.|
|`VAR_POP(expr)`|Computes the population variance of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|`VAR_SAMP(expr)`|Computes the sample variance of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|`VARIANCE(expr)`|Computes the sample variance of `expr`, equivalent to `VAR_SAMP(expr)`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|`STDDEV_POP(expr)`|Computes the population standard deviation of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|`STDDEV_SAMP(expr)`|Computes the sample standard deviation of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|`STDDEV(expr)`|Computes the sample standard deviation of `expr`, equivalent to `STDDEV_SAMP(expr)`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details.|

For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.html#approx).
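For example, `SELECT STDDEV_SAMP(added) FROM wiki` returns the sample standard deviation of a numeric column (the column and datasource names here are illustrative), and `VAR_POP`/`STDDEV_POP` return the population counterparts.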


@@ -40,6 +40,12 @@
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
@@ -53,6 +59,15 @@
<scope>test</scope>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-server</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-processing</artifactId>
@@ -60,6 +75,15 @@
<scope>test</scope>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>


@@ -30,7 +30,9 @@ import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.aggregation.variance.VarianceFoldingAggregatorFactory;
import org.apache.druid.query.aggregation.variance.VarianceSerde;
import org.apache.druid.query.aggregation.variance.sql.BaseVarianceSqlAggregator;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.sql.guice.SqlBindings;
import java.util.List;
@@ -55,6 +57,15 @@ public class DruidStatsModule implements DruidModule
@Override
public void configure(Binder binder)
{
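// The SQL bindings below are Guice-only; the null check presumably guards cases where
// the module is instantiated outside a full Guice context (e.g. serde-only registration).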
if (binder != null) {
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarPopSqlAggregator.class);
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarSampSqlAggregator.class);
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarianceSqlAggregator.class);
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevPopSqlAggregator.class);
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSampSqlAggregator.class);
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSqlAggregator.class);
}
if (ComplexMetrics.getSerdeForType("variance") == null) {
ComplexMetrics.registerSerde("variance", new VarianceSerde());
}


@@ -32,6 +32,7 @@ import org.apache.druid.query.cache.CacheKeyBuilder;
import java.util.Comparator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
@@ -121,4 +122,32 @@ public class StandardDeviationPostAggregator implements PostAggregator
.appendBoolean(isVariancePop)
.build();
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
StandardDeviationPostAggregator that = (StandardDeviationPostAggregator) o;
if (!Objects.equals(name, that.name)) {
return false;
}
if (!Objects.equals(fieldName, that.fieldName)) {
return false;
}
if (!Objects.equals(estimator, that.estimator)) {
return false;
}
if (!Objects.equals(isVariancePop, that.isVariancePop)) {
return false;
}
return true;
}
}
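One caveat: this hunk adds equals() without a matching hashCode(). Unless the class already overrides hashCode() over the same fields, equal instances may hash differently; a consistent override (a sketch only, not part of this commit) would be:

@Override
public int hashCode()
{
  // Hash exactly the fields that equals() compares, preserving the equals/hashCode contract.
  return Objects.hash(name, fieldName, estimator, isVariancePop);
}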


@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
@@ -76,7 +77,9 @@ public abstract class VarianceAggregator implements Aggregator
@Override
public void aggregate()
{
holder.add(selector.getFloat());
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
holder.add(selector.getFloat());
}
}
}
@@ -93,7 +96,9 @@ public abstract class VarianceAggregator implements Aggregator
@Override
public void aggregate()
{
holder.add(selector.getLong());
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
holder.add(selector.getLong());
}
}
}
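With this guard, a null input's contribution depends on the null-handling mode: in default-value mode selector.getFloat() reads null as 0 and the row is still aggregated, while in SQL-compatible mode the row is skipped. A self-contained sketch of the effect, using a simplified accumulator that mirrors the holder's update rule (SimpleVarianceHolder is a hypothetical name; the real logic lives in VarianceAggregatorCollector):

class SimpleVarianceHolder
{
  long count;
  double sum;
  double nvariance; // running value of count * variance

  void add(double v)
  {
    count++;
    sum += v;
    if (count > 1) {
      double t = count * v - sum;
      nvariance += (t * t) / ((double) count * (count - 1));
    }
  }

  double populationVariance()
  {
    return nvariance / count;
  }
}

For inputs {1, 2, null, 3}, SQL-compatible mode aggregates {1, 2, 3} (population variance 2/3 ≈ 0.667), while default-value mode aggregates {1, 2, 0, 3} (population variance 1.25).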


@@ -20,12 +20,12 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
import java.nio.ByteBuffer;
/**
@@ -89,15 +89,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
float v = selector.getFloat();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
buf.putDouble(position + NVARIANCE_OFFSET, variance);
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
float v = selector.getFloat();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
buf.putDouble(position + NVARIANCE_OFFSET, variance);
}
}
}
@@ -120,15 +122,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
long v = selector.getLong();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
buf.putDouble(position + NVARIANCE_OFFSET, variance);
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
long v = selector.getLong();
long count = buf.getLong(position + COUNT_OFFSET) + 1;
double sum = buf.getDouble(position + SUM_OFFSET) + v;
buf.putLong(position, count);
buf.putDouble(position + SUM_OFFSET, sum);
if (count > 1) {
double t = count * v - sum;
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
buf.putDouble(position + NVARIANCE_OFFSET, variance);
}
}
}
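The buffered update above is the single-element case of Welford-style online variance, in the form used for parallel merging: after folding in a new value v, t = count * v - sum, and the running count * variance grows by t * t / (count * (count - 1)). A standalone check of that identity against the two-pass definition (a sketch, not part of the commit):

public class IncrementalVarianceCheck
{
  public static void main(String[] args)
  {
    double[] xs = {1, 2, 3, 5, 8};
    long count = 0;
    double sum = 0;
    double nvariance = 0;
    // Incremental update, exactly as in the aggregator above.
    for (double v : xs) {
      count++;
      sum += v;
      if (count > 1) {
        double t = count * v - sum;
        nvariance += (t * t) / ((double) count * (count - 1));
      }
    }
    double mean = sum / count;
    double twoPass = 0;
    for (double v : xs) {
      twoPass += (v - mean) * (v - mean);
    }
    // Both print the same population variance, up to floating-point rounding.
    System.out.println(nvariance / count);
    System.out.println(twoPass / count);
  }
}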


@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignature;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
{
@Nullable
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
final RexNode inputOperand = Expressions.fromFieldAccess(
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression input = Expressions.toDruidExpression(
plannerContext,
rowSignature,
inputOperand
);
if (input == null) {
return null;
}
final AggregatorFactory aggregatorFactory;
final SqlTypeName sqlTypeName = inputOperand.getType().getSqlTypeName();
final ValueType inputType = Calcites.getValueTypeForSqlTypeName(sqlTypeName);
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final DimensionSpec dimensionSpec;
final String aggName = StringUtils.format("%s:agg", name);
final SqlAggFunction func = calciteFunction();
final String estimator;
final String inputTypeName;
PostAggregator postAggregator = null;
if (input.isSimpleExtraction()) {
dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
} else {
VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, sqlTypeName);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
}
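// Note: FLOAT and DOUBLE both map to the "float" implementation below, so DOUBLE columns
// are aggregated at float precision.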
if (inputType == ValueType.LONG) {
inputTypeName = "long";
} else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) {
inputTypeName = "float";
} else {
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
}
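// VAR_POP and STDDEV_POP use the population estimator; VAR_SAMP, VARIANCE, STDDEV_SAMP
// and STDDEV all use the sample estimator.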
if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
estimator = "population";
} else {
estimator = "sample";
}
aggregatorFactory = new VarianceAggregatorFactory(
aggName,
dimensionSpec.getDimension(),
estimator,
inputTypeName
);
if (func == SqlStdOperatorTable.STDDEV_POP
|| func == SqlStdOperatorTable.STDDEV_SAMP
|| func == SqlStdOperatorTable.STDDEV) {
postAggregator = new StandardDeviationPostAggregator(
name,
aggregatorFactory.getName(),
estimator);
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
postAggregator
);
}
public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VAR_POP;
}
}
public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VAR_SAMP;
}
}
public static class VarianceSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VARIANCE;
}
}
public static class StdDevPopSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV_POP;
}
}
public static class StdDevSampSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV_SAMP;
}
}
public static class StdDevSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV;
}
}
}
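Net effect: each SQL variance/stddev call plans to a native VarianceAggregatorFactory named "<output>:agg", and the STDDEV variants add a StandardDeviationPostAggregator to take the square root. For example, STDDEV_SAMP(m1) on a float column plans to roughly the following pair (a sketch mirroring the test expectations below; "a0" stands in for the planner-generated output name):

AggregatorFactory factory = new VarianceAggregatorFactory("a0:agg", "m1", "sample", "float");
PostAggregator stddev = new StandardDeviationPostAggregator("a0", "a0:agg", "sample");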


@@ -0,0 +1,518 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.impl.DimensionSchema;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.DoubleDimensionSchema;
import org.apache.druid.data.input.impl.FloatDimensionSchema;
import org.apache.druid.data.input.impl.InputRowParser;
import org.apache.druid.data.input.impl.LongDimensionSchema;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.security.AuthTestUtils;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.SqlLifecycle;
import org.apache.druid.sql.SqlLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.schema.DruidSchema;
import org.apache.druid.sql.calcite.schema.SystemSchema;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.QueryLogHook;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.util.List;
public class VarianceSqlAggregatorTest
{
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final String DATA_SOURCE = "numfoo";
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
@BeforeClass
public static void setUpClass()
{
final Pair<QueryRunnerFactoryConglomerate, Closer> conglomerateCloserPair = CalciteTests
.createQueryRunnerFactoryConglomerate();
conglomerate = conglomerateCloserPair.lhs;
resourceCloser = conglomerateCloserPair.rhs;
}
@AfterClass
public static void tearDownClass() throws IOException
{
resourceCloser.close();
}
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
public QueryLogHook queryLogHook = QueryLogHook.create();
private SpecificSegmentsQuerySegmentWalker walker;
private SqlLifecycleFactory sqlLifecycleFactory;
@Before
public void setUp() throws Exception
{
InputRowParser parser = new MapInputRowParser(
new TimeAndDimsParseSpec(
new TimestampSpec("t", "iso", null),
new DimensionsSpec(
ImmutableList.<DimensionSchema>builder()
.addAll(DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3")))
.add(new DoubleDimensionSchema("d1"))
.add(new FloatDimensionSchema("f1"))
.add(new LongDimensionSchema("l1"))
.build(),
null,
null
)
));
final QueryableIndex index =
IndexBuilder.create()
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
new IncrementalIndexSchema.Builder()
.withMetrics(
new CountAggregatorFactory("cnt"),
new DoubleSumAggregatorFactory("m1", "m1")
)
.withDimensionsSpec(parser)
.withRollup(false)
.build()
)
.rows(CalciteTests.ROWS1_WITH_NUMERIC_DIMS)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
.build(),
index
);
final PlannerConfig plannerConfig = new PlannerConfig();
final DruidSchema druidSchema = CalciteTests.createMockSchema(conglomerate, walker, plannerConfig);
final SystemSchema systemSchema = CalciteTests.createMockSystemSchema(druidSchema, walker, plannerConfig);
final DruidOperatorTable operatorTable = new DruidOperatorTable(
ImmutableSet.of(
new BaseVarianceSqlAggregator.VarPopSqlAggregator(),
new BaseVarianceSqlAggregator.VarSampSqlAggregator(),
new BaseVarianceSqlAggregator.VarianceSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevPopSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevSampSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevSqlAggregator()
),
ImmutableSet.of()
);
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
new PlannerFactory(
druidSchema,
systemSchema,
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
operatorTable,
CalciteTests.createExprMacroTable(),
plannerConfig,
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
)
);
}
@After
public void tearDown() throws Exception
{
walker.close();
walker = null;
}
public void addToHolder(VarianceAggregatorCollector holder, Object raw)
{
addToHolder(holder, raw, 1);
}
public void addToHolder(VarianceAggregatorCollector holder, Object raw, int multiply)
{
if (raw != null) {
if (raw instanceof Double) {
double v = ((Double) raw).doubleValue() * multiply;
holder.add((float) v);
} else if (raw instanceof Float) {
float v = ((Float) raw).floatValue() * multiply;
holder.add(v);
} else if (raw instanceof Long) {
long v = ((Long) raw).longValue() * multiply;
holder.add(v);
} else if (raw instanceof Integer) {
int v = ((Integer) raw).intValue() * multiply;
holder.add(v);
}
} else {
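// In default-value mode a null numeric input is read as 0 and still counted, matching
// the aggregator's behavior; in SQL-compatible mode nulls are skipped.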
if (NullHandling.replaceWithDefault()) {
holder.add(0.0f);
}
}
}
@Test
public void testVarPop() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "VAR_POP(d1),\n"
+ "VAR_POP(f1),\n"
+ "VAR_POP(l1)\n"
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
Object raw1 = row.getRaw("d1");
Object raw2 = row.getRaw("f1");
Object raw3 = row.getRaw("l1");
addToHolder(holder1, raw1);
addToHolder(holder2, raw2);
addToHolder(holder3, raw3);
}
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(true),
(float) holder2.getVariance(true),
(long) holder3.getVariance(true),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testVarSamp() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "VAR_SAMP(d1),\n"
+ "VAR_SAMP(f1),\n"
+ "VAR_SAMP(l1)\n"
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
Object raw1 = row.getRaw("d1");
Object raw2 = row.getRaw("f1");
Object raw3 = row.getRaw("l1");
addToHolder(holder1, raw1);
addToHolder(holder2, raw2);
addToHolder(holder3, raw3);
}
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(false),
(float) holder2.getVariance(false),
(long) holder3.getVariance(false),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testStdDevPop() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "STDDEV_POP(d1),\n"
+ "STDDEV_POP(f1),\n"
+ "STDDEV_POP(l1)\n"
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
Object raw1 = row.getRaw("d1");
Object raw2 = row.getRaw("f1");
Object raw3 = row.getRaw("l1");
addToHolder(holder1, raw1);
addToHolder(holder2, raw2);
addToHolder(holder3, raw3);
}
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
Math.sqrt(holder1.getVariance(true)),
(float) Math.sqrt(holder2.getVariance(true)),
(long) Math.sqrt(holder3.getVariance(true)),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
)
)
.postAggregators(
ImmutableList.of(
new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
new StandardDeviationPostAggregator("a2", "a2:agg", "population"))
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testStdDevSamp() throws Exception
{
queryLogHook.clearRecordedQueries();
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "STDDEV_SAMP(d1),\n"
+ "STDDEV_SAMP(f1),\n"
+ "STDDEV_SAMP(l1)\n"
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
Object raw1 = row.getRaw("d1");
Object raw2 = row.getRaw("f1");
Object raw3 = row.getRaw("l1");
addToHolder(holder1, raw1);
addToHolder(holder2, raw2);
addToHolder(holder3, raw3);
}
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
Math.sqrt(holder1.getVariance(false)),
(float) Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
)
)
.postAggregators(
new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testStdDevWithVirtualColumns() throws Exception
{
queryLogHook.clearRecordedQueries();
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "STDDEV(d1*7),\n"
+ "STDDEV(f1*7),\n"
+ "STDDEV(l1*7)\n"
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
Object raw1 = row.getRaw("d1");
Object raw2 = row.getRaw("f1");
Object raw3 = row.getRaw("l1");
addToHolder(holder1, raw1, 7);
addToHolder(holder2, raw2, 7);
addToHolder(holder3, raw3, 7);
}
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
Math.sqrt(holder1.getVariance(false)),
(float) Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.virtualColumns(
BaseCalciteQueryTest.expressionVirtualColumn("v0", "(\"d1\" * 7)", ValueType.DOUBLE),
BaseCalciteQueryTest.expressionVirtualColumn("v1", "(\"f1\" * 7)", ValueType.FLOAT),
BaseCalciteQueryTest.expressionVirtualColumn("v2", "(\"l1\" * 7)", ValueType.LONG)
)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"),
new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
)
)
.postAggregators(
new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
}