mirror of https://github.com/apache/druid.git
Support complex variance object inputs for variance SQL agg function (#14463)
* Support complex variance object inputs for variance SQL agg function * Add test * Include complexTypeChecker, address PR comments * Checkstyle, javadoc link
This commit is contained in:
parent
e552f68e77
commit
c36f12f1d8
|
@ -60,7 +60,7 @@ import java.util.Objects;
|
|||
@JsonTypeName("variance")
|
||||
public class VarianceAggregatorFactory extends AggregatorFactory
|
||||
{
|
||||
private static final String VARIANCE_TYPE_NAME = "variance";
|
||||
public static final String VARIANCE_TYPE_NAME = "variance";
|
||||
public static final ColumnType TYPE = ColumnType.ofComplex(VARIANCE_TYPE_NAME);
|
||||
|
||||
protected final String fieldName;
|
||||
|
|
|
@ -26,7 +26,10 @@ import org.apache.calcite.rel.type.RelDataType;
|
|||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.rex.RexNode;
|
||||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
|
@ -42,15 +45,33 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
|||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.Expressions;
|
||||
import org.apache.druid.sql.calcite.expression.OperatorConversions;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
import org.apache.druid.sql.calcite.table.RowSignatures;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
||||
{
|
||||
private static final String VARIANCE_NAME = "VARIANCE";
|
||||
private static final String STDDEV_NAME = "STDDEV";
|
||||
|
||||
private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(VARIANCE_NAME);
|
||||
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
|
||||
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
|
||||
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(STDDEV_NAME);
|
||||
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
|
||||
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Aggregation toDruidAggregation(
|
||||
|
@ -104,12 +125,13 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
|
||||
if (inputType.isNumeric()) {
|
||||
inputTypeName = StringUtils.toLowerCase(inputType.getType().name());
|
||||
} else if (inputType.equals(VarianceAggregatorFactory.TYPE)) {
|
||||
inputTypeName = VarianceAggregatorFactory.VARIANCE_TYPE_NAME;
|
||||
} else {
|
||||
throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString());
|
||||
}
|
||||
|
||||
|
||||
if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
|
||||
if (func.getName().equals(SqlKind.VAR_POP.name()) || func.getName().equals(SqlKind.STDDEV_POP.name())) {
|
||||
estimator = "population";
|
||||
} else {
|
||||
estimator = "sample";
|
||||
|
@ -122,9 +144,9 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
inputTypeName
|
||||
);
|
||||
|
||||
if (func == SqlStdOperatorTable.STDDEV_POP
|
||||
|| func == SqlStdOperatorTable.STDDEV_SAMP
|
||||
|| func == SqlStdOperatorTable.STDDEV) {
|
||||
if (func.getName().equals(STDDEV_NAME)
|
||||
|| func.getName().equals(SqlKind.STDDEV_POP.name())
|
||||
|| func.getName().equals(SqlKind.STDDEV_SAMP.name())) {
|
||||
postAggregator = new StandardDeviationPostAggregator(
|
||||
name,
|
||||
aggregatorFactory.getName(),
|
||||
|
@ -137,21 +159,40 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
|
||||
* but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
|
||||
*/
|
||||
private static SqlAggFunction buildSqlAvgAggFunction(String name)
|
||||
{
|
||||
return OperatorConversions
|
||||
.aggregatorBuilder(name)
|
||||
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
|
||||
.operandTypeChecker(
|
||||
OperandTypes.or(
|
||||
OperandTypes.NUMERIC,
|
||||
RowSignatures.complexTypeChecker(VarianceAggregatorFactory.TYPE)
|
||||
)
|
||||
)
|
||||
.functionCategory(SqlFunctionCategory.NUMERIC)
|
||||
.build();
|
||||
}
|
||||
|
||||
public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.VAR_POP;
|
||||
return VARIANCE_POP_SQL_AGG_FUNC_INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.VAR_SAMP;
|
||||
return VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -160,7 +201,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.VARIANCE;
|
||||
return VARIANCE_SQL_AGG_FUNC_INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,7 +210,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.STDDEV_POP;
|
||||
return STDDEV_POP_SQL_AGG_FUNC_INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,7 +219,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.STDDEV_SAMP;
|
||||
return STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -187,7 +228,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
{
|
||||
return SqlStdOperatorTable.STDDEV;
|
||||
return STDDEV_SQL_AGG_FUNC_INSTANCE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ import org.apache.druid.query.aggregation.stats.DruidStatsModule;
|
|||
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceSerde;
|
||||
import org.apache.druid.query.dimension.DefaultDimensionSpec;
|
||||
import org.apache.druid.query.groupby.GroupByQuery;
|
||||
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
|
||||
|
@ -51,6 +52,7 @@ import org.apache.druid.segment.QueryableIndex;
|
|||
import org.apache.druid.segment.column.ColumnType;
|
||||
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
|
||||
import org.apache.druid.segment.join.JoinableFactoryWrapper;
|
||||
import org.apache.druid.segment.serde.ComplexMetrics;
|
||||
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
|
||||
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
|
||||
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
|
||||
|
@ -82,8 +84,10 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
final Injector injector
|
||||
) throws IOException
|
||||
{
|
||||
ComplexMetrics.registerSerde(VarianceSerde.TYPE_NAME, new VarianceSerde());
|
||||
|
||||
final QueryableIndex index =
|
||||
IndexBuilder.create()
|
||||
IndexBuilder.create(CalciteTests.getJsonMapper().registerModules(new DruidStatsModule().getJacksonModules()))
|
||||
.tmpDir(temporaryFolder.newFolder())
|
||||
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
|
||||
.schema(
|
||||
|
@ -100,7 +104,8 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.withMetrics(
|
||||
new CountAggregatorFactory("cnt"),
|
||||
new DoubleSumAggregatorFactory("m1", "m1")
|
||||
new DoubleSumAggregatorFactory("m1", "m1"),
|
||||
new VarianceAggregatorFactory("var1", "m1", null, null)
|
||||
)
|
||||
.withRollup(false)
|
||||
.build()
|
||||
|
@ -624,6 +629,55 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVarianceAggAsInput()
|
||||
{
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
"3.5",
|
||||
"2.9166666666666665",
|
||||
"3.5",
|
||||
"1.8708286933869707",
|
||||
"1.707825127659933",
|
||||
"1.8708286933869707"
|
||||
}
|
||||
);
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ "VARIANCE(var1),\n"
|
||||
+ "VAR_POP(var1),\n"
|
||||
+ "VAR_SAMP(var1),\n"
|
||||
+ "STDDEV(var1),\n"
|
||||
+ "STDDEV_POP(var1),\n"
|
||||
+ "STDDEV_SAMP(var1)\n"
|
||||
+ "FROM numfoo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
ImmutableList.of(
|
||||
new VarianceAggregatorFactory("a0:agg", "var1", "sample", "variance"),
|
||||
new VarianceAggregatorFactory("a1:agg", "var1", "population", "variance"),
|
||||
new VarianceAggregatorFactory("a2:agg", "var1", "sample", "variance"),
|
||||
new VarianceAggregatorFactory("a3:agg", "var1", "sample", "variance"),
|
||||
new VarianceAggregatorFactory("a4:agg", "var1", "population", "variance"),
|
||||
new VarianceAggregatorFactory("a5:agg", "var1", "sample", "variance")
|
||||
)
|
||||
)
|
||||
.postAggregators(
|
||||
new StandardDeviationPostAggregator("a3", "a3:agg", "sample"),
|
||||
new StandardDeviationPostAggregator("a4", "a4:agg", "population"),
|
||||
new StandardDeviationPostAggregator("a5", "a5:agg", "sample")
|
||||
)
|
||||
.context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
expectedResults
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
|
||||
{
|
||||
|
|
|
@ -23,11 +23,18 @@ import com.google.common.base.Preconditions;
|
|||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.rel.type.RelDataTypeComparability;
|
||||
import org.apache.calcite.rel.type.RelDataTypeFactory;
|
||||
import org.apache.calcite.sql.SqlCallBinding;
|
||||
import org.apache.calcite.sql.SqlNode;
|
||||
import org.apache.calcite.sql.SqlOperandCountRange;
|
||||
import org.apache.calcite.sql.SqlOperator;
|
||||
import org.apache.calcite.sql.type.AbstractSqlType;
|
||||
import org.apache.calcite.sql.type.SqlOperandCountRanges;
|
||||
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.ordering.StringComparator;
|
||||
import org.apache.druid.query.ordering.StringComparators;
|
||||
import org.apache.druid.segment.column.ColumnHolder;
|
||||
|
@ -79,7 +86,9 @@ public class RowSignatures
|
|||
{
|
||||
Preconditions.checkNotNull(simpleExtraction, "simpleExtraction");
|
||||
if (simpleExtraction.getExtractionFn() != null
|
||||
|| rowSignature.getColumnType(simpleExtraction.getColumn()).map(type -> type.is(ValueType.STRING)).orElse(false)) {
|
||||
|| rowSignature.getColumnType(simpleExtraction.getColumn())
|
||||
.map(type -> type.is(ValueType.STRING))
|
||||
.orElse(false)) {
|
||||
return StringComparators.LEXICOGRAPHIC;
|
||||
} else {
|
||||
return StringComparators.NUMERIC;
|
||||
|
@ -164,7 +173,7 @@ public class RowSignatures
|
|||
* Creates a {@link ComplexSqlType} using the supplied {@link RelDataTypeFactory} to ensure that the
|
||||
* {@link ComplexSqlType} is interned. This is important because Calcite checks that the references are equal
|
||||
* instead of the objects being equivalent.
|
||||
*
|
||||
* <p>
|
||||
* This method uses {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean)}, which ensures that if the
|
||||
* type factory is a {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through
|
||||
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type.
|
||||
|
@ -179,15 +188,15 @@ public class RowSignatures
|
|||
|
||||
/**
|
||||
* Calcite {@link RelDataType} for Druid complex columns, to preserve complex type information.
|
||||
*
|
||||
* <p>
|
||||
* If using with other operations of a {@link RelDataTypeFactory}, consider wrapping the creation of this type in
|
||||
* {@link RelDataTypeFactory#createTypeWithNullability(RelDataType, boolean)} to ensure that if the type factory is a
|
||||
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl} that the type is passed through
|
||||
* {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#canonize(RelDataType)} which interns the type.
|
||||
*
|
||||
* <p>
|
||||
* If {@link SqlTypeName} is going to be {@link SqlTypeName#OTHER} and a {@link RelDataTypeFactory} is available,
|
||||
* consider using {@link #makeComplexType(RelDataTypeFactory, ColumnType, boolean)}.
|
||||
*
|
||||
* <p>
|
||||
* This type does not work well with {@link org.apache.calcite.sql.type.ReturnTypes#explicit(RelDataType)}, which
|
||||
* will create new {@link RelDataType} using {@link SqlTypeName} during return type inference, so implementors of
|
||||
* {@link org.apache.druid.sql.calcite.expression.SqlOperatorConversion} should implement the
|
||||
|
@ -235,4 +244,67 @@ public class RowSignatures
|
|||
return columnType.asTypeString();
|
||||
}
|
||||
}
|
||||
|
||||
public static ComplexSqlSingleOperandTypeChecker complexTypeChecker(ColumnType complexType)
|
||||
{
|
||||
return new ComplexSqlSingleOperandTypeChecker(
|
||||
new ComplexSqlType(SqlTypeName.OTHER, complexType, true)
|
||||
);
|
||||
}
|
||||
|
||||
public static final class ComplexSqlSingleOperandTypeChecker implements SqlSingleOperandTypeChecker
|
||||
{
|
||||
private final ComplexSqlType type;
|
||||
|
||||
public ComplexSqlSingleOperandTypeChecker(
|
||||
ComplexSqlType type
|
||||
)
|
||||
{
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkSingleOperandType(
|
||||
SqlCallBinding callBinding,
|
||||
SqlNode operand,
|
||||
int iFormalOperand,
|
||||
boolean throwOnFailure
|
||||
)
|
||||
{
|
||||
return type.equals(callBinding.getValidator().deriveType(callBinding.getScope(), operand));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
|
||||
{
|
||||
if (callBinding.getOperandCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
return checkSingleOperandType(callBinding, callBinding.operand(0), 0, throwOnFailure);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqlOperandCountRange getOperandCountRange()
|
||||
{
|
||||
return SqlOperandCountRanges.of(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getAllowedSignatures(SqlOperator op, String opName)
|
||||
{
|
||||
return StringUtils.format("'%s'(%s)", opName, type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Consistency getConsistency()
|
||||
{
|
||||
return Consistency.NONE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isOptional(int i)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue