From c36f12f1d8bf81cf110dd41853627be175816b00 Mon Sep 17 00:00:00 2001 From: Jonathan Wei Date: Wed, 28 Jun 2023 13:14:19 -0500 Subject: [PATCH] Support complex variance object inputs for variance SQL agg function (#14463) * Support complex variance object inputs for variance SQL agg function * Add test * Include complexTypeChecker, address PR comments * Checkstyle, javadoc link --- .../variance/VarianceAggregatorFactory.java | 2 +- .../sql/BaseVarianceSqlAggregator.java | 67 ++++++++++++--- .../sql/VarianceSqlAggregatorTest.java | 58 ++++++++++++- .../sql/calcite/table/RowSignatures.java | 82 +++++++++++++++++-- 4 files changed, 188 insertions(+), 21 deletions(-) diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java index 47eccfbffd5..40d06bbbe02 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java @@ -60,7 +60,7 @@ import java.util.Objects; @JsonTypeName("variance") public class VarianceAggregatorFactory extends AggregatorFactory { - private static final String VARIANCE_TYPE_NAME = "variance"; + public static final String VARIANCE_TYPE_NAME = "variance"; public static final ColumnType TYPE = ColumnType.ofComplex(VARIANCE_TYPE_NAME); protected final String fieldName; diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index 3eb3f498161..0b1562eb83d 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ 
b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -26,7 +26,10 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -42,15 +45,33 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; +import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; import java.util.List; public abstract class BaseVarianceSqlAggregator implements SqlAggregator { + private static final String VARIANCE_NAME = "VARIANCE"; + private static final String STDDEV_NAME = "STDDEV"; + + private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE = + buildSqlAvgAggFunction(VARIANCE_NAME); + private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE = + buildSqlAvgAggFunction(SqlKind.VAR_POP.name()); + private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE = + buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name()); + private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE = + 
buildSqlAvgAggFunction(STDDEV_NAME); + private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE = + buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name()); + private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE = + buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name()); + @Nullable @Override public Aggregation toDruidAggregation( @@ -104,12 +125,13 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator if (inputType.isNumeric()) { inputTypeName = StringUtils.toLowerCase(inputType.getType().name()); + } else if (inputType.equals(VarianceAggregatorFactory.TYPE)) { + inputTypeName = VarianceAggregatorFactory.VARIANCE_TYPE_NAME; } else { throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString()); } - - if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) { + if (func.getName().equals(SqlKind.VAR_POP.name()) || func.getName().equals(SqlKind.STDDEV_POP.name())) { estimator = "population"; } else { estimator = "sample"; @@ -122,9 +144,9 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator inputTypeName ); - if (func == SqlStdOperatorTable.STDDEV_POP - || func == SqlStdOperatorTable.STDDEV_SAMP - || func == SqlStdOperatorTable.STDDEV) { + if (func.getName().equals(STDDEV_NAME) + || func.getName().equals(SqlKind.STDDEV_POP.name()) + || func.getName().equals(SqlKind.STDDEV_SAMP.name())) { postAggregator = new StandardDeviationPostAggregator( name, aggregatorFactory.getName(), @@ -137,21 +159,40 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator ); } + /** + * Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction} + * but with an operand type that accepts variance aggregator objects in addition to numeric inputs. 
+ */ + private static SqlAggFunction buildSqlAvgAggFunction(String name) + { + return OperatorConversions + .aggregatorBuilder(name) + .returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION) + .operandTypeChecker( + OperandTypes.or( + OperandTypes.NUMERIC, + RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE) + ) + ) + .functionCategory(SqlFunctionCategory.NUMERIC) + .build(); + } + public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator { @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.VAR_POP; + return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE; } } - + public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator { @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.VAR_SAMP; + return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE; } } @@ -160,7 +201,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.VARIANCE; + return VARIANCE_SQL_AGG_FUNC_INSTANCE; } } @@ -169,7 +210,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.STDDEV_POP; + return STDDEV_POP_SQL_AGG_FUNC_INSTANCE; } } @@ -178,7 +219,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.STDDEV_SAMP; + return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE; } } @@ -187,7 +228,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator @Override public SqlAggFunction calciteFunction() { - return SqlStdOperatorTable.STDDEV; + return STDDEV_SQL_AGG_FUNC_INSTANCE; } } } diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java 
index bc1ef681693..5c496c46635 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -40,6 +40,7 @@ import org.apache.druid.query.aggregation.stats.DruidStatsModule; import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator; import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector; import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory; +import org.apache.druid.query.aggregation.variance.VarianceSerde; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; @@ -51,6 +52,7 @@ import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.incremental.IncrementalIndexSchema; import org.apache.druid.segment.join.JoinableFactoryWrapper; +import org.apache.druid.segment.serde.ComplexMetrics; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.sql.calcite.BaseCalciteQueryTest; @@ -82,8 +84,10 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest final Injector injector ) throws IOException { + ComplexMetrics.registerSerde(VarianceSerde.TYPE_NAME, new VarianceSerde()); + final QueryableIndex index = - IndexBuilder.create() + IndexBuilder.create(CalciteTests.getJsonMapper().registerModules(new DruidStatsModule().getJacksonModules())) .tmpDir(temporaryFolder.newFolder()) .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) .schema( @@ -100,7 +104,8 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest ) .withMetrics( new 
CountAggregatorFactory("cnt"), - new DoubleSumAggregatorFactory("m1", "m1") + new DoubleSumAggregatorFactory("m1", "m1"), + new VarianceAggregatorFactory("var1", "m1", null, null) ) .withRollup(false) .build() @@ -624,6 +629,55 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest ); } + @Test + public void testVarianceAggAsInput() + { + final List expectedResults = ImmutableList.of( + new Object[]{ + "3.5", + "2.9166666666666665", + "3.5", + "1.8708286933869707", + "1.707825127659933", + "1.8708286933869707" + } + ); + testQuery( + "SELECT\n" + + "VARIANCE(var1),\n" + + "VAR_POP(var1),\n" + + "VAR_SAMP(var1),\n" + + "STDDEV(var1),\n" + + "STDDEV_POP(var1),\n" + + "STDDEV_SAMP(var1)\n" + + "FROM numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators( + ImmutableList.of( + new VarianceAggregatorFactory("a0:agg", "var1", "sample", "variance"), + new VarianceAggregatorFactory("a1:agg", "var1", "population", "variance"), + new VarianceAggregatorFactory("a2:agg", "var1", "sample", "variance"), + new VarianceAggregatorFactory("a3:agg", "var1", "sample", "variance"), + new VarianceAggregatorFactory("a4:agg", "var1", "population", "variance"), + new VarianceAggregatorFactory("a5:agg", "var1", "sample", "variance") + ) + ) + .postAggregators( + new StandardDeviationPostAggregator("a3", "a3:agg", "sample"), + new StandardDeviationPostAggregator("a4", "a4:agg", "population"), + new StandardDeviationPostAggregator("a5", "a5:agg", "sample") + ) + .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT) + .build() + ), + expectedResults + ); + } + @Override public void assertResultsEquals(String sql, List expectedResults, List results) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java 
b/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java index 32abe56ee8d..87519c75374 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/table/RowSignatures.java @@ -23,11 +23,18 @@ import com.google.common.base.Preconditions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeComparability; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperandCountRange; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.AbstractSqlType; +import org.apache.calcite.sql.type.SqlOperandCountRanges; +import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.ordering.StringComparator; import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.segment.column.ColumnHolder; @@ -79,7 +86,9 @@ public class RowSignatures { Preconditions.checkNotNull(simpleExtraction, "simpleExtraction"); if (simpleExtraction.getExtractionFn() != null - || rowSignature.getColumnType(simpleExtraction.getColumn()).map(type -> type.is(ValueType.STRING)).orElse(false)) { + || rowSignature.getColumnType(simpleExtraction.getColumn()) + .map(type -> type.is(ValueType.STRING)) + .orElse(false)) { return StringComparators.LEXICOGRAPHIC; } else { return StringComparators.NUMERIC; @@ -164,7 +173,7 @@ public class RowSignatures * Creates a {@link ComplexSqlType} using the supplied {@link RelDataTypeFactory} to ensure that the * {@link ComplexSqlType} is interned. 
This is important because Calcite checks that the references are equal * instead of the objects being equivalent. - * + * <p>
* This method uses {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) ensures that if the * type factory is a {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through * {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type. @@ -179,15 +188,15 @@ public class RowSignatures /** * Calcite {@link RelDataType} for Druid complex columns, to preserve complex type information. - * + * <p>
* If using with other operations of a {@link RelDataTypeFactory}, consider wrapping the creation of this type in * {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) to ensure that if the type factory is a * {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through * {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type. - * + * <p>
* If {@link SqlTypeName} is going to be {@link SqlTypeName#OTHER} and a {@link RelDataTypeFactory} is available, * consider using {@link #makeComplexType(RelDataTypeFactory, ColumnType, boolean)}. - * + * <p>
* This type does not work well with {@link org.apache.calcite.sql.type.ReturnTypes#explicit(RelDataType)}, which * will create new {@link RelDataType} using {@link SqlTypeName} during return type inference, so implementors of * {@link org.apache.druid.sql.calcite.expression.SqlOperatorConversion} should implement the @@ -235,4 +244,67 @@ public class RowSignatures return columnType.asTypeString(); } } + + public static ComplexSqlSingleOperandTypeChecker complexTypeChecker(ColumnType complexType) + { + return new ComplexSqlSingleOperandTypeChecker( + new ComplexSqlType(SqlTypeName.OTHER, complexType, true) + ); + } + + public static final class ComplexSqlSingleOperandTypeChecker implements SqlSingleOperandTypeChecker + { + private final ComplexSqlType type; + + public ComplexSqlSingleOperandTypeChecker( + ComplexSqlType type + ) + { + this.type = type; + } + + @Override + public boolean checkSingleOperandType( + SqlCallBinding callBinding, + SqlNode operand, + int iFormalOperand, + boolean throwOnFailure + ) + { + return type.equals(callBinding.getValidator().deriveType(callBinding.getScope(), operand)); + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) + { + if (callBinding.getOperandCount() != 1) { + return false; + } + return checkSingleOperandType(callBinding, callBinding.operand(0), 0, throwOnFailure); + } + + @Override + public SqlOperandCountRange getOperandCountRange() + { + return SqlOperandCountRanges.of(1); + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) + { + return StringUtils.format("'%s'(%s)", opName, type); + } + + @Override + public Consistency getConsistency() + { + return Consistency.NONE; + } + + @Override + public boolean isOptional(int i) + { + return false; + } + } }