Support complex variance object inputs for variance SQL agg function (#14463)

* Support complex variance object inputs for variance SQL agg function

* Add test

* Include complexTypeChecker, address PR comments

* Checkstyle, javadoc link
This commit is contained in:
Jonathan Wei 2023-06-28 13:14:19 -05:00 committed by GitHub
parent e552f68e77
commit c36f12f1d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 188 additions and 21 deletions

View File

@ -60,7 +60,7 @@ import java.util.Objects;
@JsonTypeName("variance")
public class VarianceAggregatorFactory extends AggregatorFactory
{
private static final String VARIANCE_TYPE_NAME = "variance";
public static final String VARIANCE_TYPE_NAME = "variance";
public static final ColumnType TYPE = ColumnType.ofComplex(VARIANCE_TYPE_NAME);
protected final String fieldName;

View File

@ -26,7 +26,10 @@ import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
@ -42,15 +45,33 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;
import javax.annotation.Nullable;
import java.util.List;
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
{
private static final String VARIANCE_NAME = "VARIANCE";
private static final String STDDEV_NAME = "STDDEV";
private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(VARIANCE_NAME);
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(STDDEV_NAME);
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
@Nullable
@Override
public Aggregation toDruidAggregation(
@ -104,12 +125,13 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
if (inputType.isNumeric()) {
inputTypeName = StringUtils.toLowerCase(inputType.getType().name());
} else if (inputType.equals(VarianceAggregatorFactory.TYPE)) {
inputTypeName = VarianceAggregatorFactory.VARIANCE_TYPE_NAME;
} else {
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString());
}
if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
if (func.getName().equals(SqlKind.VAR_POP.name()) || func.getName().equals(SqlKind.STDDEV_POP.name())) {
estimator = "population";
} else {
estimator = "sample";
@ -122,9 +144,9 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
inputTypeName
);
if (func == SqlStdOperatorTable.STDDEV_POP
|| func == SqlStdOperatorTable.STDDEV_SAMP
|| func == SqlStdOperatorTable.STDDEV) {
if (func.getName().equals(STDDEV_NAME)
|| func.getName().equals(SqlKind.STDDEV_POP.name())
|| func.getName().equals(SqlKind.STDDEV_SAMP.name())) {
postAggregator = new StandardDeviationPostAggregator(
name,
aggregatorFactory.getName(),
@ -137,21 +159,40 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
);
}
/**
* Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
* but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
*/
private static SqlAggFunction buildSqlAvgAggFunction(String name)
{
return OperatorConversions
.aggregatorBuilder(name)
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
.operandTypeChecker(
OperandTypes.or(
OperandTypes.NUMERIC,
RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE)
)
)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
}
public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VAR_POP;
return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE;
}
}
public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VAR_SAMP;
return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE;
}
}
@ -160,7 +201,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.VARIANCE;
return VARIANCE_SQL_AGG_FUNC_INSTANCE;
}
}
@ -169,7 +210,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV_POP;
return STDDEV_POP_SQL_AGG_FUNC_INSTANCE;
}
}
@ -178,7 +219,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV_SAMP;
return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE;
}
}
@ -187,7 +228,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.STDDEV;
return STDDEV_SQL_AGG_FUNC_INSTANCE;
}
}
}

View File

@ -40,6 +40,7 @@ import org.apache.druid.query.aggregation.stats.DruidStatsModule;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.aggregation.variance.VarianceSerde;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
@ -51,6 +52,7 @@ import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
@ -82,8 +84,10 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
final Injector injector
) throws IOException
{
ComplexMetrics.registerSerde(VarianceSerde.TYPE_NAME, new VarianceSerde());
final QueryableIndex index =
IndexBuilder.create()
IndexBuilder.create(CalciteTests.getJsonMapper().registerModules(new DruidStatsModule().getJacksonModules()))
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
@ -100,7 +104,8 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
)
.withMetrics(
new CountAggregatorFactory("cnt"),
new DoubleSumAggregatorFactory("m1", "m1")
new DoubleSumAggregatorFactory("m1", "m1"),
new VarianceAggregatorFactory("var1", "m1", null, null)
)
.withRollup(false)
.build()
@ -624,6 +629,55 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
);
}
@Test
public void testVarianceAggAsInput()
{
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
"3.5",
"2.9166666666666665",
"3.5",
"1.8708286933869707",
"1.707825127659933",
"1.8708286933869707"
}
);
testQuery(
"SELECT\n"
+ "VARIANCE(var1),\n"
+ "VAR_POP(var1),\n"
+ "VAR_SAMP(var1),\n"
+ "STDDEV(var1),\n"
+ "STDDEV_POP(var1),\n"
+ "STDDEV_SAMP(var1)\n"
+ "FROM numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
ImmutableList.of(
new VarianceAggregatorFactory("a0:agg", "var1", "sample", "variance"),
new VarianceAggregatorFactory("a1:agg", "var1", "population", "variance"),
new VarianceAggregatorFactory("a2:agg", "var1", "sample", "variance"),
new VarianceAggregatorFactory("a3:agg", "var1", "sample", "variance"),
new VarianceAggregatorFactory("a4:agg", "var1", "population", "variance"),
new VarianceAggregatorFactory("a5:agg", "var1", "sample", "variance")
)
)
.postAggregators(
new StandardDeviationPostAggregator("a3", "a3:agg", "sample"),
new StandardDeviationPostAggregator("a4", "a4:agg", "population"),
new StandardDeviationPostAggregator("a5", "a5:agg", "sample")
)
.context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build()
),
expectedResults
);
}
@Override
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
{

View File

@ -23,11 +23,18 @@ import com.google.common.base.Preconditions;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeComparability;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.AbstractSqlType;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.ordering.StringComparator;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.segment.column.ColumnHolder;
@ -79,7 +86,9 @@ public class RowSignatures
{
Preconditions.checkNotNull(simpleExtraction, "simpleExtraction");
if (simpleExtraction.getExtractionFn() != null
|| rowSignature.getColumnType(simpleExtraction.getColumn()).map(type -> type.is(ValueType.STRING)).orElse(false)) {
|| rowSignature.getColumnType(simpleExtraction.getColumn())
.map(type -> type.is(ValueType.STRING))
.orElse(false)) {
return StringComparators.LEXICOGRAPHIC;
} else {
return StringComparators.NUMERIC;
@ -164,7 +173,7 @@ public class RowSignatures
* Creates a {@link ComplexSqlType} using the supplied {@link RelDataTypeFactory} to ensure that the
* {@link ComplexSqlType} is interned. This is important because Calcite checks that the references are equal
* instead of the objects being equivalent.
*
* <p>
* This method uses {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) ensures that if the
* type factory is a {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type.
@ -179,15 +188,15 @@ public class RowSignatures
/**
* Calcite {@link RelDataType} for Druid complex columns, to preserve complex type information.
*
* <p>
* If using with other operations of a {@link RelDataTypeFactory}, consider wrapping the creation of this type in
* {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean) to ensure that if the type factory is a
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type.
*
* <p>
* If {@link SqlTypeName} is going to be {@link SqlTypeName#OTHER} and a {@link RelDataTypeFactory} is available,
* consider using {@link #makeComplexType(RelDataTypeFactory, ColumnType, boolean)}.
*
* <p>
* This type does not work well with {@link org.apache.calcite.sql.type.ReturnTypes#explicit(RelDataType)}, which
* will create new {@link RelDataType} using {@link SqlTypeName} during return type inference, so implementors of
* {@link org.apache.druid.sql.calcite.expression.SqlOperatorConversion} should implement the
@ -235,4 +244,67 @@ public class RowSignatures
return columnType.asTypeString();
}
}
public static ComplexSqlSingleOperandTypeChecker complexTypeChecker(ColumnType complexType)
{
return new ComplexSqlSingleOperandTypeChecker(
new ComplexSqlType(SqlTypeName.OTHER, complexType, true)
);
}
public static final class ComplexSqlSingleOperandTypeChecker implements SqlSingleOperandTypeChecker
{
private final ComplexSqlType type;
public ComplexSqlSingleOperandTypeChecker(
ComplexSqlType type
)
{
this.type = type;
}
@Override
public boolean checkSingleOperandType(
SqlCallBinding callBinding,
SqlNode operand,
int iFormalOperand,
boolean throwOnFailure
)
{
return type.equals(callBinding.getValidator().deriveType(callBinding.getScope(), operand));
}
@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
{
if (callBinding.getOperandCount() != 1) {
return false;
}
return checkSingleOperandType(callBinding, callBinding.operand(0), 0, throwOnFailure);
}
@Override
public SqlOperandCountRange getOperandCountRange()
{
return SqlOperandCountRanges.of(1);
}
@Override
public String getAllowedSignatures(SqlOperator op, String opName)
{
return StringUtils.format("'%s'(%s)", opName, type);
}
@Override
public Consistency getConsistency()
{
return Consistency.NONE;
}
@Override
public boolean isOptional(int i)
{
return false;
}
}
}