mirror of https://github.com/apache/druid.git
Support var_pop, var_samp, stddev_pop and stddev_samp etc in sql (#7801)
* support var_pop, stddev_pop etc in sql * fix sql compatible * rebase on master * update doc
This commit is contained in:
parent
c612ddc0f4
commit
ce591d1457
|
@ -129,6 +129,13 @@ Only the COUNT aggregation can accept DISTINCT.
|
|||
|`APPROX_QUANTILE_DS(expr, probability, [k])`|Computes approximate quantiles on numeric or [Quantiles sketch](../development/extensions-core/datasketches-quantiles.html) exprs. The "probability" should be between 0 and 1 (exclusive). The `k` parameter is described in the Quantiles sketch documentation. The [DataSketches extension](../development/extensions-core/datasketches-extension.html) must be loaded to use this function.|
|
||||
|`APPROX_QUANTILE_FIXED_BUCKETS(expr, probability, numBuckets, lowerLimit, upperLimit, [outlierHandlingMode])`|Computes approximate quantiles on numeric or [fixed buckets histogram](../development/extensions-core/approximate-histograms.html#fixed-buckets-histogram) exprs. The "probability" should be between 0 and 1 (exclusive). The `numBuckets`, `lowerLimit`, `upperLimit`, and `outlierHandlingMode` parameters are described in the fixed buckets histogram documentation. The [approximate histogram extension](../development/extensions-core/approximate-histograms.html) must be loaded to use this function.|
|
||||
|`BLOOM_FILTER(expr, numEntries)`|Computes a bloom filter from values produced by `expr`, where `numEntries` is the maximum number of distinct values before the false positive rate increases. See [bloom filter extension](../development/extensions-core/bloom-filter.html) documentation for additional details.|
|
||||
|`VAR_POP(expr)`|Computes the population variance of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|
||||
|`VAR_SAMP(expr)`|Computes the sample variance of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|
||||
|`VARIANCE(expr)`|Computes the sample variance of `expr` (same result as `VAR_SAMP`). See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|
||||
|`STDDEV_POP(expr)`|Computes the population standard deviation of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|
||||
|`STDDEV_SAMP(expr)`|Computes the sample standard deviation of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|
||||
|`STDDEV(expr)`|Computes the sample standard deviation of `expr` (same result as `STDDEV_SAMP`). See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
|
||||
|
||||
|
||||
For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.html#approx).
|
||||
|
||||
|
|
|
@ -40,6 +40,12 @@
|
|||
<version>${project.parent.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.druid</groupId>
|
||||
<artifactId>druid-sql</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-math3</artifactId>
|
||||
|
@ -53,6 +59,15 @@
|
|||
<scope>test</scope>
|
||||
<type>test-jar</type>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.druid</groupId>
|
||||
<artifactId>druid-server</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<scope>test</scope>
|
||||
<type>test-jar</type>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.druid</groupId>
|
||||
<artifactId>druid-processing</artifactId>
|
||||
|
@ -60,6 +75,15 @@
|
|||
<scope>test</scope>
|
||||
<type>test-jar</type>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.druid</groupId>
|
||||
<artifactId>druid-sql</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<scope>test</scope>
|
||||
<type>test-jar</type>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
|
|
|
@ -30,7 +30,9 @@ import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregat
|
|||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceFoldingAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceSerde;
|
||||
import org.apache.druid.query.aggregation.variance.sql.BaseVarianceSqlAggregator;
|
||||
import org.apache.druid.segment.serde.ComplexMetrics;
|
||||
import org.apache.druid.sql.guice.SqlBindings;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -55,6 +57,15 @@ public class DruidStatsModule implements DruidModule
|
|||
@Override
|
||||
public void configure(Binder binder)
|
||||
{
|
||||
if (binder != null) {
|
||||
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarPopSqlAggregator.class);
|
||||
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarSampSqlAggregator.class);
|
||||
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarianceSqlAggregator.class);
|
||||
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevPopSqlAggregator.class);
|
||||
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSampSqlAggregator.class);
|
||||
SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSqlAggregator.class);
|
||||
}
|
||||
|
||||
if (ComplexMetrics.getSerdeForType("variance") == null) {
|
||||
ComplexMetrics.registerSerde("variance", new VarianceSerde());
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.druid.query.cache.CacheKeyBuilder;
|
|||
|
||||
import java.util.Comparator;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
|
@ -121,4 +122,32 @@ public class StandardDeviationPostAggregator implements PostAggregator
|
|||
.appendBoolean(isVariancePop)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o)
|
||||
{
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
StandardDeviationPostAggregator that = (StandardDeviationPostAggregator) o;
|
||||
|
||||
if (!Objects.equals(name, that.name)) {
|
||||
return false;
|
||||
}
|
||||
if (!Objects.equals(fieldName, that.fieldName)) {
|
||||
return false;
|
||||
}
|
||||
if (!Objects.equals(estimator, that.estimator)) {
|
||||
return false;
|
||||
}
|
||||
if (!Objects.equals(isVariancePop, that.isVariancePop)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
package org.apache.druid.query.aggregation.variance;
|
||||
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.query.aggregation.Aggregator;
|
||||
import org.apache.druid.segment.BaseFloatColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseLongColumnValueSelector;
|
||||
|
@ -76,7 +77,9 @@ public abstract class VarianceAggregator implements Aggregator
|
|||
@Override
|
||||
public void aggregate()
|
||||
{
|
||||
holder.add(selector.getFloat());
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
holder.add(selector.getFloat());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -93,7 +96,9 @@ public abstract class VarianceAggregator implements Aggregator
|
|||
@Override
|
||||
public void aggregate()
|
||||
{
|
||||
holder.add(selector.getLong());
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
holder.add(selector.getLong());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,12 +20,12 @@
|
|||
package org.apache.druid.query.aggregation.variance;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
|
||||
import org.apache.druid.segment.BaseFloatColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseLongColumnValueSelector;
|
||||
import org.apache.druid.segment.BaseObjectColumnValueSelector;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
|
@ -89,15 +89,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
|
|||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position)
|
||||
{
|
||||
float v = selector.getFloat();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
buf.putLong(position, count);
|
||||
buf.putDouble(position + SUM_OFFSET, sum);
|
||||
if (count > 1) {
|
||||
double t = count * v - sum;
|
||||
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
|
||||
buf.putDouble(position + NVARIANCE_OFFSET, variance);
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
float v = selector.getFloat();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
buf.putLong(position, count);
|
||||
buf.putDouble(position + SUM_OFFSET, sum);
|
||||
if (count > 1) {
|
||||
double t = count * v - sum;
|
||||
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
|
||||
buf.putDouble(position + NVARIANCE_OFFSET, variance);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,15 +122,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
|
|||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position)
|
||||
{
|
||||
long v = selector.getLong();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
buf.putLong(position, count);
|
||||
buf.putDouble(position + SUM_OFFSET, sum);
|
||||
if (count > 1) {
|
||||
double t = count * v - sum;
|
||||
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
|
||||
buf.putDouble(position + NVARIANCE_OFFSET, variance);
|
||||
if (NullHandling.replaceWithDefault() || !selector.isNull()) {
|
||||
long v = selector.getLong();
|
||||
long count = buf.getLong(position + COUNT_OFFSET) + 1;
|
||||
double sum = buf.getDouble(position + SUM_OFFSET) + v;
|
||||
buf.putLong(position, count);
|
||||
buf.putDouble(position + SUM_OFFSET, sum);
|
||||
if (count > 1) {
|
||||
double t = count * v - sum;
|
||||
double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
|
||||
buf.putDouble(position + NVARIANCE_OFFSET, variance);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation.variance.sql;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.core.Project;
|
||||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.PostAggregator;
|
||||
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
|
||||
import org.apache.druid.query.dimension.DefaultDimensionSpec;
|
||||
import org.apache.druid.query.dimension.DimensionSpec;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.Expressions;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
import org.apache.druid.sql.calcite.table.RowSignature;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
||||
{
|
||||
@Nullable
|
||||
@Override
|
||||
public Aggregation toDruidAggregation(
|
||||
PlannerContext plannerContext,
|
||||
RowSignature rowSignature,
|
||||
VirtualColumnRegistry virtualColumnRegistry,
|
||||
RexBuilder rexBuilder,
|
||||
String name,
|
||||
AggregateCall aggregateCall,
|
||||
Project project,
|
||||
List<Aggregation> existingAggregations,
|
||||
boolean finalizeAggregations
|
||||
)
|
||||
{
|
||||
final RexNode inputOperand = Expressions.fromFieldAccess(
|
||||
rowSignature,
|
||||
project,
|
||||
aggregateCall.getArgList().get(0)
|
||||
);
|
||||
final DruidExpression input = Expressions.toDruidExpression(
|
||||
plannerContext,
|
||||
rowSignature,
|
||||
inputOperand
|
||||
);
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
final AggregatorFactory aggregatorFactory;
|
||||
final SqlTypeName sqlTypeName = inputOperand.getType().getSqlTypeName();
|
||||
final ValueType inputType = Calcites.getValueTypeForSqlTypeName(sqlTypeName);
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
final DimensionSpec dimensionSpec;
|
||||
final String aggName = StringUtils.format("%s:agg", name);
|
||||
final SqlAggFunction func = calciteFunction();
|
||||
final String estimator;
|
||||
final String inputTypeName;
|
||||
PostAggregator postAggregator = null;
|
||||
|
||||
if (input.isSimpleExtraction()) {
|
||||
dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
|
||||
} else {
|
||||
VirtualColumn virtualColumn =
|
||||
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, sqlTypeName);
|
||||
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
|
||||
virtualColumns.add(virtualColumn);
|
||||
}
|
||||
|
||||
if (inputType == ValueType.LONG) {
|
||||
inputTypeName = "long";
|
||||
} else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) {
|
||||
inputTypeName = "float";
|
||||
} else {
|
||||
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
|
||||
}
|
||||
|
||||
if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
|
||||
estimator = "population";
|
||||
} else {
|
||||
estimator = "sample";
|
||||
}
|
||||
|
||||
aggregatorFactory = new VarianceAggregatorFactory(
|
||||
aggName,
|
||||
dimensionSpec.getDimension(),
|
||||
estimator,
|
||||
inputTypeName
|
||||
);
|
||||
|
||||
if (func == SqlStdOperatorTable.STDDEV_POP
|
||||
|| func == SqlStdOperatorTable.STDDEV_SAMP
|
||||
|| func == SqlStdOperatorTable.STDDEV) {
|
||||
postAggregator = new StandardDeviationPostAggregator(
|
||||
name,
|
||||
aggregatorFactory.getName(),
|
||||
estimator);
|
||||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
postAggregator
|
||||
);
|
||||
}
|
||||
|
||||
public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.VAR_POP;
|
||||
}
|
||||
}
|
||||
|
||||
public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.VAR_SAMP;
|
||||
}
|
||||
}
|
||||
|
||||
public static class VarianceSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.VARIANCE;
|
||||
}
|
||||
}
|
||||
|
||||
public static class StdDevPopSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.STDDEV_POP;
|
||||
}
|
||||
}
|
||||
|
||||
public static class StdDevSampSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.STDDEV_SAMP;
|
||||
}
|
||||
}
|
||||
|
||||
public static class StdDevSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.STDDEV;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,518 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation.variance.sql;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Iterables;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.data.input.InputRow;
|
||||
import org.apache.druid.data.input.impl.DimensionSchema;
|
||||
import org.apache.druid.data.input.impl.DimensionsSpec;
|
||||
import org.apache.druid.data.input.impl.DoubleDimensionSchema;
|
||||
import org.apache.druid.data.input.impl.FloatDimensionSchema;
|
||||
import org.apache.druid.data.input.impl.InputRowParser;
|
||||
import org.apache.druid.data.input.impl.LongDimensionSchema;
|
||||
import org.apache.druid.data.input.impl.MapInputRowParser;
|
||||
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
|
||||
import org.apache.druid.data.input.impl.TimestampSpec;
|
||||
import org.apache.druid.java.util.common.Pair;
|
||||
import org.apache.druid.java.util.common.granularity.Granularities;
|
||||
import org.apache.druid.java.util.common.io.Closer;
|
||||
import org.apache.druid.query.Druids;
|
||||
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
|
||||
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
|
||||
import org.apache.druid.segment.IndexBuilder;
|
||||
import org.apache.druid.segment.QueryableIndex;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
|
||||
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
|
||||
import org.apache.druid.server.security.AuthTestUtils;
|
||||
import org.apache.druid.server.security.AuthenticationResult;
|
||||
import org.apache.druid.sql.SqlLifecycle;
|
||||
import org.apache.druid.sql.SqlLifecycleFactory;
|
||||
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
|
||||
import org.apache.druid.sql.calcite.filtration.Filtration;
|
||||
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerConfig;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerFactory;
|
||||
import org.apache.druid.sql.calcite.schema.DruidSchema;
|
||||
import org.apache.druid.sql.calcite.schema.SystemSchema;
|
||||
import org.apache.druid.sql.calcite.util.CalciteTests;
|
||||
import org.apache.druid.sql.calcite.util.QueryLogHook;
|
||||
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
|
||||
import org.apache.druid.timeline.DataSegment;
|
||||
import org.apache.druid.timeline.partition.LinearShardSpec;
|
||||
import org.junit.After;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
public class VarianceSqlAggregatorTest
|
||||
{
|
||||
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
|
||||
private static final String DATA_SOURCE = "numfoo";
|
||||
|
||||
private static QueryRunnerFactoryConglomerate conglomerate;
|
||||
private static Closer resourceCloser;
|
||||
|
||||
@BeforeClass
|
||||
public static void setUpClass()
|
||||
{
|
||||
final Pair<QueryRunnerFactoryConglomerate, Closer> conglomerateCloserPair = CalciteTests
|
||||
.createQueryRunnerFactoryConglomerate();
|
||||
conglomerate = conglomerateCloserPair.lhs;
|
||||
resourceCloser = conglomerateCloserPair.rhs;
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void tearDownClass() throws IOException
|
||||
{
|
||||
resourceCloser.close();
|
||||
}
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder temporaryFolder = new TemporaryFolder();
|
||||
|
||||
@Rule
|
||||
public QueryLogHook queryLogHook = QueryLogHook.create();
|
||||
|
||||
private SpecificSegmentsQuerySegmentWalker walker;
|
||||
private SqlLifecycleFactory sqlLifecycleFactory;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception
|
||||
{
|
||||
InputRowParser parser = new MapInputRowParser(
|
||||
new TimeAndDimsParseSpec(
|
||||
new TimestampSpec("t", "iso", null),
|
||||
new DimensionsSpec(
|
||||
ImmutableList.<DimensionSchema>builder()
|
||||
.addAll(DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3")))
|
||||
.add(new DoubleDimensionSchema("d1"))
|
||||
.add(new FloatDimensionSchema("f1"))
|
||||
.add(new LongDimensionSchema("l1"))
|
||||
.build(),
|
||||
null,
|
||||
null
|
||||
)
|
||||
));
|
||||
|
||||
final QueryableIndex index =
|
||||
IndexBuilder.create()
|
||||
.tmpDir(temporaryFolder.newFolder())
|
||||
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
|
||||
.schema(
|
||||
new IncrementalIndexSchema.Builder()
|
||||
.withMetrics(
|
||||
new CountAggregatorFactory("cnt"),
|
||||
new DoubleSumAggregatorFactory("m1", "m1")
|
||||
)
|
||||
.withDimensionsSpec(parser)
|
||||
.withRollup(false)
|
||||
.build()
|
||||
)
|
||||
.rows(CalciteTests.ROWS1_WITH_NUMERIC_DIMS)
|
||||
.buildMMappedIndex();
|
||||
|
||||
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
|
||||
DataSegment.builder()
|
||||
.dataSource(DATA_SOURCE)
|
||||
.interval(index.getDataInterval())
|
||||
.version("1")
|
||||
.shardSpec(new LinearShardSpec(0))
|
||||
.build(),
|
||||
index
|
||||
);
|
||||
|
||||
final PlannerConfig plannerConfig = new PlannerConfig();
|
||||
final DruidSchema druidSchema = CalciteTests.createMockSchema(conglomerate, walker, plannerConfig);
|
||||
final SystemSchema systemSchema = CalciteTests.createMockSystemSchema(druidSchema, walker, plannerConfig);
|
||||
final DruidOperatorTable operatorTable = new DruidOperatorTable(
|
||||
ImmutableSet.of(
|
||||
new BaseVarianceSqlAggregator.VarPopSqlAggregator(),
|
||||
new BaseVarianceSqlAggregator.VarSampSqlAggregator(),
|
||||
new BaseVarianceSqlAggregator.VarianceSqlAggregator(),
|
||||
new BaseVarianceSqlAggregator.StdDevPopSqlAggregator(),
|
||||
new BaseVarianceSqlAggregator.StdDevSampSqlAggregator(),
|
||||
new BaseVarianceSqlAggregator.StdDevSqlAggregator()
|
||||
),
|
||||
ImmutableSet.of()
|
||||
);
|
||||
|
||||
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
|
||||
new PlannerFactory(
|
||||
druidSchema,
|
||||
systemSchema,
|
||||
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
|
||||
operatorTable,
|
||||
CalciteTests.createExprMacroTable(),
|
||||
plannerConfig,
|
||||
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
|
||||
CalciteTests.getJsonMapper()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception
|
||||
{
|
||||
walker.close();
|
||||
walker = null;
|
||||
}
|
||||
|
||||
public void addToHolder(VarianceAggregatorCollector holder, Object raw)
|
||||
{
|
||||
addToHolder(holder, raw, 1);
|
||||
}
|
||||
|
||||
public void addToHolder(VarianceAggregatorCollector holder, Object raw, int multiply)
|
||||
{
|
||||
if (raw != null) {
|
||||
if (raw instanceof Double) {
|
||||
double v = ((Double) raw).doubleValue() * multiply;
|
||||
holder.add((float) v);
|
||||
} else if (raw instanceof Float) {
|
||||
float v = ((Float) raw).floatValue() * multiply;
|
||||
holder.add(v);
|
||||
} else if (raw instanceof Long) {
|
||||
long v = ((Long) raw).longValue() * multiply;
|
||||
holder.add(v);
|
||||
} else if (raw instanceof Integer) {
|
||||
int v = ((Integer) raw).intValue() * multiply;
|
||||
holder.add(v);
|
||||
}
|
||||
} else {
|
||||
if (NullHandling.replaceWithDefault()) {
|
||||
holder.add(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVarPop() throws Exception
|
||||
{
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "VAR_POP(d1),\n"
|
||||
+ "VAR_POP(f1),\n"
|
||||
+ "VAR_POP(l1)\n"
|
||||
+ "FROM numfoo";
|
||||
|
||||
final List<Object[]> results =
|
||||
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
|
||||
|
||||
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
|
||||
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
|
||||
Object raw1 = row.getRaw("d1");
|
||||
Object raw2 = row.getRaw("f1");
|
||||
Object raw3 = row.getRaw("l1");
|
||||
addToHolder(holder1, raw1);
|
||||
addToHolder(holder2, raw2);
|
||||
addToHolder(holder3, raw3);
|
||||
}
|
||||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
holder1.getVariance(true),
|
||||
(float) holder2.getVariance(true),
|
||||
(long) holder3.getVariance(true),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
|
||||
)
|
||||
)
|
||||
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVarSamp() throws Exception
|
||||
{
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "VAR_SAMP(d1),\n"
|
||||
+ "VAR_SAMP(f1),\n"
|
||||
+ "VAR_SAMP(l1)\n"
|
||||
+ "FROM numfoo";
|
||||
|
||||
final List<Object[]> results =
|
||||
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
|
||||
|
||||
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
|
||||
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
|
||||
Object raw1 = row.getRaw("d1");
|
||||
Object raw2 = row.getRaw("f1");
|
||||
Object raw3 = row.getRaw("l1");
|
||||
addToHolder(holder1, raw1);
|
||||
addToHolder(holder2, raw2);
|
||||
addToHolder(holder3, raw3);
|
||||
}
|
||||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
holder1.getVariance(false),
|
||||
(float) holder2.getVariance(false),
|
||||
(long) holder3.getVariance(false),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
|
||||
)
|
||||
)
|
||||
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStdDevPop() throws Exception
|
||||
{
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "STDDEV_POP(d1),\n"
|
||||
+ "STDDEV_POP(f1),\n"
|
||||
+ "STDDEV_POP(l1)\n"
|
||||
+ "FROM numfoo";
|
||||
|
||||
final List<Object[]> results =
|
||||
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
|
||||
|
||||
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
|
||||
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
|
||||
Object raw1 = row.getRaw("d1");
|
||||
Object raw2 = row.getRaw("f1");
|
||||
Object raw3 = row.getRaw("l1");
|
||||
addToHolder(holder1, raw1);
|
||||
addToHolder(holder2, raw2);
|
||||
addToHolder(holder3, raw3);
|
||||
}
|
||||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
Math.sqrt(holder1.getVariance(true)),
|
||||
(float) Math.sqrt(holder2.getVariance(true)),
|
||||
(long) Math.sqrt(holder3.getVariance(true)),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
|
||||
)
|
||||
)
|
||||
.postAggregators(
|
||||
ImmutableList.of(
|
||||
new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
|
||||
new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
|
||||
new StandardDeviationPostAggregator("a2", "a2:agg", "population"))
|
||||
)
|
||||
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStdDevSamp() throws Exception
|
||||
{
|
||||
queryLogHook.clearRecordedQueries();
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "STDDEV_SAMP(d1),\n"
|
||||
+ "STDDEV_SAMP(f1),\n"
|
||||
+ "STDDEV_SAMP(l1)\n"
|
||||
+ "FROM numfoo";
|
||||
|
||||
final List<Object[]> results =
|
||||
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
|
||||
|
||||
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
|
||||
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
|
||||
Object raw1 = row.getRaw("d1");
|
||||
Object raw2 = row.getRaw("f1");
|
||||
Object raw3 = row.getRaw("l1");
|
||||
addToHolder(holder1, raw1);
|
||||
addToHolder(holder2, raw2);
|
||||
addToHolder(holder3, raw3);
|
||||
}
|
||||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
Math.sqrt(holder1.getVariance(false)),
|
||||
(float) Math.sqrt(holder2.getVariance(false)),
|
||||
(long) Math.sqrt(holder3.getVariance(false)),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
|
||||
)
|
||||
)
|
||||
.postAggregators(
|
||||
new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
|
||||
new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
|
||||
new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
|
||||
)
|
||||
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStdDevWithVirtualColumns() throws Exception
|
||||
{
|
||||
queryLogHook.clearRecordedQueries();
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "STDDEV(d1*7),\n"
|
||||
+ "STDDEV(f1*7),\n"
|
||||
+ "STDDEV(l1*7)\n"
|
||||
+ "FROM numfoo";
|
||||
|
||||
final List<Object[]> results =
|
||||
sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
|
||||
|
||||
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
|
||||
VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
|
||||
for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
|
||||
Object raw1 = row.getRaw("d1");
|
||||
Object raw2 = row.getRaw("f1");
|
||||
Object raw3 = row.getRaw("l1");
|
||||
addToHolder(holder1, raw1, 7);
|
||||
addToHolder(holder2, raw2, 7);
|
||||
addToHolder(holder3, raw3, 7);
|
||||
}
|
||||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
Math.sqrt(holder1.getVariance(false)),
|
||||
(float) Math.sqrt(holder2.getVariance(false)),
|
||||
(long) Math.sqrt(holder3.getVariance(false)),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.virtualColumns(
|
||||
BaseCalciteQueryTest.expressionVirtualColumn("v0", "(\"d1\" * 7)", ValueType.DOUBLE),
|
||||
BaseCalciteQueryTest.expressionVirtualColumn("v1", "(\"f1\" * 7)", ValueType.FLOAT),
|
||||
BaseCalciteQueryTest.expressionVirtualColumn("v2", "(\"l1\" * 7)", ValueType.LONG)
|
||||
)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
|
||||
new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
|
||||
)
|
||||
)
|
||||
.postAggregators(
|
||||
new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
|
||||
new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
|
||||
new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
|
||||
)
|
||||
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue