diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactory.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactory.java index 833df8ab1a5..7d0bb79c71e 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactory.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactory.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.apache.datasketches.hll.HllSketch; import org.apache.datasketches.hll.TgtHllType; import org.apache.datasketches.hll.Union; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringEncoding; import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -34,7 +35,9 @@ import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import javax.annotation.Nullable; @@ -107,6 +110,8 @@ public class HllSketchMergeAggregatorFactory extends HllSketchAggregatorFactory @Override public Aggregator factorize(final ColumnSelectorFactory columnSelectorFactory) { + validateInputs(columnSelectorFactory.getColumnCapabilities(getFieldName())); + final ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(getFieldName()); return new HllSketchMergeAggregator(selector, getLgK(), TgtHllType.valueOf(getTgtHllType())); } @@ -115,6 +120,8 @@ public class HllSketchMergeAggregatorFactory extends HllSketchAggregatorFactory @Override public BufferAggregator factorizeBuffered(final ColumnSelectorFactory columnSelectorFactory) { + validateInputs(columnSelectorFactory.getColumnCapabilities(getFieldName())); + final ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(getFieldName()); return new HllSketchMergeBufferAggregator( selector, @@ -133,6 +140,7 @@ public class HllSketchMergeAggregatorFactory extends HllSketchAggregatorFactory @Override public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) { + validateInputs(selectorFactory.getColumnCapabilities(getFieldName())); return new HllSketchMergeVectorAggregator( selectorFactory, getFieldName(), @@ -142,6 +150,34 @@ public class HllSketchMergeAggregatorFactory extends HllSketchAggregatorFactory ); } + /** + * Validates whether the aggregator supports the input column type. + * Supported column types are complex types of HLLSketch, HLLSketchBuild, HLLSketchMerge, as well as UNKNOWN_COMPLEX. 
+ * @param capabilities capabilities of the input column, or null if unknown; null capabilities are not validated + */ + private void validateInputs(@Nullable ColumnCapabilities capabilities) + { + if (capabilities != null) { + final ColumnType type = capabilities.toColumnType(); + boolean isSupportedComplexType = ValueType.COMPLEX.equals(type.getType()) && + ( + HllSketchModule.TYPE_NAME.equals(type.getComplexTypeName()) || + HllSketchModule.BUILD_TYPE_NAME.equals(type.getComplexTypeName()) || + HllSketchModule.MERGE_TYPE_NAME.equals(type.getComplexTypeName()) || + type.getComplexTypeName() == null + ); + if (!isSupportedComplexType) { + throw DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.UNSUPPORTED) + .build( + "Using aggregator [%s] is not supported for complex columns with type [%s].", + getIntermediateType().getComplexTypeName(), + type + ); + } + } + } + @Override public int getMaxIntermediateSize() { diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java index 757674d6aa6..7d303d27274 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java @@ -21,28 +21,68 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers; import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.StringEncoding; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.datasketches.hll.HllSketchBuildAggregatorFactory; +import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggregatorFactory; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.table.RowSignatures; import java.util.Collections; +/** + * Approximate count distinct aggregator using HLL sketches. + * Supported column types: String, Numeric, HLLSketchMerge, HLLSketchBuild.
+ */ public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator { public static final String NAME = "APPROX_COUNT_DISTINCT_DS_HLL"; + + private static final SqlSingleOperandTypeChecker AGGREGATED_COLUMN_TYPE_CHECKER = OperandTypes.or( + OperandTypes.STRING, + OperandTypes.NUMERIC, + RowSignatures.complexTypeChecker(HllSketchMergeAggregatorFactory.TYPE), + RowSignatures.complexTypeChecker(HllSketchBuildAggregatorFactory.TYPE) + ); + private static final SqlAggFunction FUNCTION_INSTANCE = OperatorConversions.aggregatorBuilder(NAME) - .operandNames("column", "lgK", "tgtHllType") - .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING) .operandTypeInference(InferTypes.VARCHAR_1024) - .requiredOperandCount(1) - .literalOperands(1, 2) + .operandTypeChecker( + OperandTypes.or( + // APPROX_COUNT_DISTINCT_DS_HLL(column) + AGGREGATED_COLUMN_TYPE_CHECKER, + // APPROX_COUNT_DISTINCT_DS_HLL(column, lgk) + OperandTypes.and( + OperandTypes.sequence( + StringUtils.format("'%s(column, lgk)'", NAME), + AGGREGATED_COLUMN_TYPE_CHECKER, + CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL + ), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC) + ), + // APPROX_COUNT_DISTINCT_DS_HLL(column, lgk, tgtHllType) + OperandTypes.and( + OperandTypes.sequence( + StringUtils.format("'%s(column, lgk, tgtHllType)'", NAME), + AGGREGATED_COLUMN_TYPE_CHECKER, + CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL, + OperandTypes.STRING + ), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC, SqlTypeFamily.STRING) + ) + ) + ) .returnTypeNonNull(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.NUMERIC) .build(); diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java index d221b72ac1c..15221c0f6f8 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java @@ -31,6 +31,7 @@ import org.apache.druid.query.aggregation.datasketches.SketchQueryContext; import org.apache.druid.query.aggregation.datasketches.hll.HllSketchAggregatorFactory; import org.apache.druid.query.aggregation.datasketches.hll.HllSketchBuildAggregatorFactory; import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggregatorFactory; +import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; @@ -40,6 +41,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; @@ -115,7 +117,7 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator if 
(columnArg.isDirectColumnAccess() && inputAccessor.getInputRowSignature() .getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) + .map(this::isValidComplexInputType) .orElse(false)) { aggregatorFactory = new HllSketchMergeAggregatorFactory( aggregatorName, @@ -154,6 +156,15 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator } if (inputType.is(ValueType.COMPLEX)) { + if (!isValidComplexInputType(inputType)) { + plannerContext.setPlanningError( + "Using APPROX_COUNT_DISTINCT() or enabling approximation with COUNT(DISTINCT) is not supported for" + + " column type [%s]. You can disable approximation by setting [%s: false] in the query context.", + columnArg.getDruidType(), + PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT + ); + return null; + } aggregatorFactory = new HllSketchMergeAggregatorFactory( aggregatorName, dimensionSpec.getOutputName(), @@ -192,4 +203,11 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator boolean finalizeAggregations, AggregatorFactory aggregatorFactory ); + + private boolean isValidComplexInputType(ColumnType columnType) + { + return HllSketchMergeAggregatorFactory.TYPE.equals(columnType) || + HllSketchModule.TYPE_NAME.equals(columnType.getComplexTypeName()) || + HllSketchModule.BUILD_TYPE_NAME.equals(columnType.getComplexTypeName()); + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactory.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactory.java index 211373e873b..b24e382ec0a 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactory.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactory.java @@ -27,7 +27,7 @@ import org.apache.datasketches.common.Util; import org.apache.datasketches.theta.SetOperation; import org.apache.datasketches.theta.Union; import org.apache.datasketches.thetacommon.ThetaUtil; -import org.apache.druid.error.InvalidInput; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.Aggregator; @@ -41,6 +41,7 @@ import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import javax.annotation.Nullable; @@ -80,10 +81,7 @@ public abstract class SketchAggregatorFactory extends AggregatorFactory @Override public Aggregator factorize(ColumnSelectorFactory metricFactory) { - ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName); - if (capabilities != null && capabilities.isArray()) { - throw InvalidInput.exception("ARRAY types are not supported for theta sketch"); - } + validateInputs(metricFactory.getColumnCapabilities(fieldName)); BaseObjectColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName); return new SketchAggregator(selector, size); } @@ -91,10 +89,7 @@ public abstract class SketchAggregatorFactory extends AggregatorFactory @Override public AggregatorAndSize 
factorizeWithSize(ColumnSelectorFactory metricFactory) { - ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName); - if (capabilities != null && capabilities.isArray()) { - throw InvalidInput.exception("ARRAY types are not supported for theta sketch"); - } + validateInputs(metricFactory.getColumnCapabilities(fieldName)); BaseObjectColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName); final SketchAggregator aggregator = new SketchAggregator(selector, size); return new AggregatorAndSize(aggregator, aggregator.getInitialSizeBytes()); @@ -104,10 +99,7 @@ public abstract class SketchAggregatorFactory extends AggregatorFactory @Override public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) { - ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName); - if (capabilities != null && capabilities.isArray()) { - throw InvalidInput.exception("ARRAY types are not supported for theta sketch"); - } + validateInputs(metricFactory.getColumnCapabilities(fieldName)); BaseObjectColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName); return new SketchBufferAggregator(selector, size, getMaxIntermediateSizeWithNulls()); } @@ -115,9 +107,41 @@ public abstract class SketchAggregatorFactory extends AggregatorFactory @Override public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) { + validateInputs(selectorFactory.getColumnCapabilities(fieldName)); return new SketchVectorAggregator(selectorFactory, fieldName, size, getMaxIntermediateSizeWithNulls()); } + /** + * Validates whether the aggregator supports the input column type. + * Unsupported column types are ARRAY types and COMPLEX types other than the theta sketch types + * ({@link SketchModule#THETA_SKETCH_TYPE}, {@link SketchModule#MERGE_TYPE} and {@link SketchModule#BUILD_TYPE}). + * + * @param capabilities capabilities of the input column, or null if unknown; null capabilities are not validated + */ + private void validateInputs(@Nullable ColumnCapabilities capabilities) + { + if (capabilities != null) { + boolean isUnsupportedComplexType = capabilities.is(ValueType.COMPLEX) && !( + SketchModule.THETA_SKETCH_TYPE.equals(capabilities.toColumnType()) || + SketchModule.MERGE_TYPE.equals(capabilities.toColumnType()) || + SketchModule.BUILD_TYPE.equals(capabilities.toColumnType()) + ); + + if (capabilities.isArray() || isUnsupportedComplexType) { + throw DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.UNSUPPORTED) + .build( + "Unsupported input [%s] of type [%s] for aggregator [%s].", + getFieldName(), + capabilities.asTypeString(), + getIntermediateType() + ); + } + } + } + @Override public boolean canVectorize(ColumnInspector columnInspector) { diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java index eac77901f1d..5ecd289c728 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java @@ -21,27 +21,55 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers; import org.apache.calcite.sql.type.InferTypes; +import
org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.datasketches.theta.SketchModule; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.table.RowSignatures; import java.util.Collections; +/** + * Approximate count distinct aggregator using theta sketches. + * Supported column types: String, Numeric, Theta Sketch. + */ public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator { public static final String NAME = "APPROX_COUNT_DISTINCT_DS_THETA"; + + private static final SqlSingleOperandTypeChecker AGGREGATED_COLUMN_TYPE_CHECKER = OperandTypes.or( + OperandTypes.STRING, + OperandTypes.NUMERIC, + RowSignatures.complexTypeChecker(SketchModule.THETA_SKETCH_TYPE) + ); + private static final SqlAggFunction FUNCTION_INSTANCE = OperatorConversions.aggregatorBuilder(NAME) - .operandNames("column", "size") - .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) .operandTypeInference(InferTypes.VARCHAR_1024) - .requiredOperandCount(1) - .literalOperands(1) + .operandTypeChecker( + OperandTypes.or( + // APPROX_COUNT_DISTINCT_DS_THETA(expr) + AGGREGATED_COLUMN_TYPE_CHECKER, + // APPROX_COUNT_DISTINCT_DS_THETA(expr, size) + OperandTypes.and( + OperandTypes.sequence( + StringUtils.format("'%s(expr, size)'", NAME), + AGGREGATED_COLUMN_TYPE_CHECKER, + CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL + ), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC) + ) + ) + ) .returnTypeNonNull(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) .build(); diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java index bf35cd665ae..1f45f31496a 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java @@ -29,6 +29,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.datasketches.SketchQueryContext; import org.apache.druid.query.aggregation.datasketches.theta.SketchAggregatorFactory; import org.apache.druid.query.aggregation.datasketches.theta.SketchMergeAggregatorFactory; +import org.apache.druid.query.aggregation.datasketches.theta.SketchModule; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; @@ -38,6 +39,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import 
org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; @@ -95,7 +97,11 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator if (columnArg.isDirectColumnAccess() && inputAccessor.getInputRowSignature() .getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) + .map(type -> ( + SketchModule.THETA_SKETCH_TYPE.equals(type) || + SketchModule.MERGE_TYPE.equals(type) || + SketchModule.BUILD_TYPE.equals(type) + )) .orElse(false)) { aggregatorFactory = new SketchMergeAggregatorFactory( aggregatorName, @@ -116,6 +122,16 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator ); } + if (inputType.is(ValueType.COMPLEX)) { + plannerContext.setPlanningError( + "Using APPROX_COUNT_DISTINCT() or enabling approximation with COUNT(DISTINCT) is not supported for" + + " column type [%s]. You can disable approximation by setting [%s: false] in the query context.", + columnArg.getDruidType(), + PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT + ); + return null; + } + final DimensionSpec dimensionSpec; if (columnArg.isDirectColumnAccess()) { diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactoryTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactoryTest.java index 101b25b99be..fcecef62d4a 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactoryTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeAggregatorFactoryTest.java @@ -22,10 +22,17 @@ package org.apache.druid.query.aggregation.datasketches.hll; import com.fasterxml.jackson.databind.ObjectMapper; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.datasketches.hll.TgtHllType; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringEncoding; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactoryNotMergeableException; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.TestColumnSelectorFactory; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -45,6 +52,9 @@ public class HllSketchMergeAggregatorFactoryTest private HllSketchMergeAggregatorFactory targetRound; private HllSketchMergeAggregatorFactory targetNoRound; + private ColumnSelectorFactory metricFactory; + private VectorColumnSelectorFactory vectorFactory; + @Before public void setUp() { @@ -66,6 +76,10 @@ public class HllSketchMergeAggregatorFactoryTest SHOULD_FINALIZE, !ROUND ); + + final ColumnCapabilitiesImpl columnCapabilities = ColumnCapabilitiesImpl.createDefault().setType(ColumnType.NESTED_DATA); + metricFactory = new TestColumnSelectorFactory().addCapabilities(FIELD_NAME, columnCapabilities); + 
vectorFactory = new TestVectorColumnSelectorFactory().addCapabilities(FIELD_NAME, columnCapabilities); } @Test(expected = AggregatorFactoryNotMergeableException.class) @@ -291,4 +305,39 @@ public class HllSketchMergeAggregatorFactoryTest Assert.assertEquals(factory, factory.withName(targetRound.getName())); Assert.assertEquals("newTest", factory.withName("newTest").getName()); } + + @Test + public void testFactorizeOnUnsupportedComplexColumn() + { + final ColumnSelectorFactory metricFactory = new TestColumnSelectorFactory() + .addCapabilities( + FIELD_NAME, + ColumnCapabilitiesImpl.createDefault().setType(ColumnType.NESTED_DATA) + ); + Throwable exception = Assert.assertThrows(DruidException.class, () -> targetRound.factorize(metricFactory)); + Assert.assertEquals( + "Using aggregator [HLLSketchMerge] is not supported for complex columns with type [COMPLEX].", + exception.getMessage() + ); + } + + @Test + public void testFactorizeBufferedOnUnsupportedComplexColumn() + { + Throwable exception = Assert.assertThrows(DruidException.class, () -> targetRound.factorizeBuffered(metricFactory)); + Assert.assertEquals( + "Using aggregator [HLLSketchMerge] is not supported for complex columns with type [COMPLEX].", + exception.getMessage() + ); + } + + @Test + public void testFactorizeVectorOnUnsupportedComplexColumn() + { + Throwable exception = Assert.assertThrows(DruidException.class, () -> targetRound.factorizeVector(vectorFactory)); + Assert.assertEquals( + "Using aggregator [HLLSketchMerge] is not supported for complex columns with type [COMPLEX].", + exception.getMessage() + ); + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 2907d6f8bb8..edb7dc5a11f 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -24,7 +24,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Injector; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.error.DruidException; import org.apache.druid.guice.DruidInjectorBuilder; +import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringEncoding; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; @@ -86,6 +88,7 @@ import org.apache.druid.sql.calcite.util.TestDataBuilder; import org.apache.druid.sql.guice.SqlModule; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; +import org.apache.druid.timeline.partition.NumberedShardSpec; import org.joda.time.DateTimeZone; import org.joda.time.Period; import org.junit.Assert; @@ -100,6 +103,10 @@ import java.util.stream.Collectors; @SqlTestFrameworkConfig.ComponentSupplier(HllSketchComponentSupplier.class) public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest { + static { + NullHandling.initializeForTests(); + } + private static final boolean ROUND = true; // For testHllSketchPostAggsGroupBy, testHllSketchPostAggsTimeseries @@ -300,6 +307,15 @@ public class HllSketchSqlAggregatorTest extends 
BaseCalciteQueryTest .size(0) .build(), index + ).add( + DataSegment.builder() + .dataSource(CalciteTests.WIKIPEDIA_FIRST_LAST) + .interval(Intervals.of("2015-09-12/2015-09-13")) + .version("1") + .shardSpec(new NumberedShardSpec(0, 0)) + .size(0) + .build(), + TestDataBuilder.makeWikipediaIndexWithAggregation(tempDirProducer.newTempFolder()) ); } } @@ -508,6 +524,33 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest ); } + @Test + public void testApproxCountDistinctOnUnsupportedComplexColumn() + { + assertQueryIsUnplannable( + "SELECT COUNT(distinct double_first_added) FROM druid.wikipedia_first_last", + "Query could not be planned. A possible reason is [Using APPROX_COUNT_DISTINCT() or enabling " + + "approximation with COUNT(DISTINCT) is not supported for column type [COMPLEX]." + + " You can disable approximation by setting [useApproximateCountDistinct: false] in the query context." + ); + } + + @Test + public void testApproxCountDistinctFunctionOnUnsupportedComplexColumn() + { + DruidException druidException = Assert.assertThrows( + DruidException.class, + () -> testQuery( + "SELECT APPROX_COUNT_DISTINCT_DS_HLL(double_first_added) FROM druid.wikipedia_first_last", + ImmutableList.of(), + ImmutableList.of() + ) + ); + Assert.assertTrue(druidException.getMessage().contains( + "Cannot apply 'APPROX_COUNT_DISTINCT_DS_HLL' to arguments of type 'APPROX_COUNT_DISTINCT_DS_HLL(>)'" + )); + } + @Test public void testHllSketchFilteredAggregatorsGroupBy() { diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactoryTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactoryTest.java index 1d70ff30f25..23887652a73 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactoryTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchAggregatorFactoryTest.java @@ -20,6 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.theta; import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.query.Druids; import org.apache.druid.query.aggregation.AggregatorAndSize; @@ -32,10 +33,15 @@ import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.TestColumnSelectorFactory; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.easymock.EasyMock; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; public class SketchAggregatorFactoryTest @@ -46,6 +52,17 @@ public class SketchAggregatorFactoryTest private static final SketchMergeAggregatorFactory AGGREGATOR_32768 = new SketchMergeAggregatorFactory("x", "x", 32768, null, false, null); + private ColumnSelectorFactory metricFactory; + private VectorColumnSelectorFactory vectorFactory; + + @Before + public void setup() + { + final 
ColumnCapabilitiesImpl columnCapabilities = ColumnCapabilitiesImpl.createDefault().setType(ColumnType.NESTED_DATA); + metricFactory = new TestColumnSelectorFactory().addCapabilities("x", columnCapabilities); + vectorFactory = new TestVectorColumnSelectorFactory().addCapabilities("x", columnCapabilities); + } + @Test public void testGuessAggregatorHeapFootprint() { @@ -168,4 +185,32 @@ public class SketchAggregatorFactoryTest Assert.assertEquals(AGGREGATOR_16384, AGGREGATOR_16384.withName("x")); Assert.assertEquals("newTest", AGGREGATOR_16384.withName("newTest").getName()); } + + @Test + public void testFactorizeOnUnsupportedComplexColumn() + { + Throwable exception = Assert.assertThrows(DruidException.class, () -> AGGREGATOR_16384.factorize(metricFactory)); + Assert.assertEquals("Unsupported input [x] of type [COMPLEX] for aggregator [COMPLEX].", exception.getMessage()); + } + + @Test + public void testFactorizeWithSizeOnUnsupportedComplexColumn() + { + Throwable exception = Assert.assertThrows(DruidException.class, () -> AGGREGATOR_16384.factorizeWithSize(metricFactory)); + Assert.assertEquals("Unsupported input [x] of type [COMPLEX] for aggregator [COMPLEX].", exception.getMessage()); + } + + @Test + public void testFactorizeBufferedOnUnsupportedComplexColumn() + { + Throwable exception = Assert.assertThrows(DruidException.class, () -> AGGREGATOR_16384.factorizeBuffered(metricFactory)); + Assert.assertEquals("Unsupported input [x] of type [COMPLEX] for aggregator [COMPLEX].", exception.getMessage()); + } + + @Test + public void testFactorizeVectorOnUnsupportedComplexColumn() + { + Throwable exception = Assert.assertThrows(DruidException.class, () -> AGGREGATOR_16384.factorizeVector(vectorFactory)); + Assert.assertEquals("Unsupported input [x] of type [COMPLEX] for aggregator [COMPLEX].", exception.getMessage()); + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java index 247f924357a..7afd2710ccd 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java @@ -23,7 +23,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.inject.Injector; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.error.DruidException; import org.apache.druid.guice.DruidInjectorBuilder; +import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.PeriodGranularity; @@ -71,6 +73,7 @@ import org.apache.druid.sql.calcite.util.TestDataBuilder; import org.apache.druid.sql.guice.SqlModule; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; +import org.apache.druid.timeline.partition.NumberedShardSpec; import org.joda.time.DateTimeZone; import org.joda.time.Period; import org.junit.Assert; @@ -158,6 +161,15 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest .size(0) .build(), index + ).add( + DataSegment.builder() + 
.dataSource(CalciteTests.WIKIPEDIA_FIRST_LAST) + .interval(Intervals.of("2015-09-12/2015-09-13")) + .version("1") + .shardSpec(new NumberedShardSpec(0, 0)) + .size(0) + .build(), + TestDataBuilder.makeWikipediaIndexWithAggregation(tempDirProducer.newTempFolder()) ); } } @@ -373,6 +385,33 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest ); } + @Test + public void testApproxCountDistinctOnUnsupportedComplexColumn() + { + assertQueryIsUnplannable( + "SELECT COUNT(distinct double_first_added) FROM druid.wikipedia_first_last", + "Query could not be planned. A possible reason is [Using APPROX_COUNT_DISTINCT() or enabling " + + "approximation with COUNT(DISTINCT) is not supported for column type [COMPLEX]." + + " You can disable approximation by setting [useApproximateCountDistinct: false] in the query context." + ); + } + + @Test + public void testApproxCountDistinctFunctionOnUnsupportedComplexColumn() + { + DruidException druidException = Assert.assertThrows( + DruidException.class, + () -> testQuery( + "SELECT APPROX_COUNT_DISTINCT_DS_THETA(double_first_added) FROM druid.wikipedia_first_last", + ImmutableList.of(), + ImmutableList.of() + ) + ); + Assert.assertTrue(druidException.getMessage().contains( + "Cannot apply 'APPROX_COUNT_DISTINCT_DS_THETA' to arguments of type 'APPROX_COUNT_DISTINCT_DS_THETA(>)'" + )); + } + @Test public void testThetaSketchPostAggs() { diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactory.java index c1c55a826b1..c90b651fffa 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactory.java @@ -21,8 +21,8 @@ package org.apache.druid.query.aggregation.hyperloglog; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.error.DruidException; import org.apache.druid.hll.HyperLogLogCollector; -import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.query.aggregation.AggregateCombiner; @@ -107,12 +107,8 @@ public class HyperUniquesAggregatorFactory extends AggregatorFactory if (selector instanceof NilColumnValueSelector) { return NoopAggregator.instance(); } - final Class classOfObject = selector.classOfObject(); - if (classOfObject.equals(Object.class) || HyperLogLogCollector.class.isAssignableFrom(classOfObject)) { - return new HyperUniquesAggregator(selector); - } - - throw new IAE("Incompatible type for metric[%s], expected a HyperUnique, got a %s", fieldName, classOfObject); + validateInputs(metricFactory.getColumnCapabilities(fieldName)); + return new HyperUniquesAggregator(selector); } @Override @@ -122,25 +118,43 @@ public class HyperUniquesAggregatorFactory extends AggregatorFactory if (selector instanceof NilColumnValueSelector) { return NoopBufferAggregator.instance(); } - final Class classOfObject = selector.classOfObject(); - if (classOfObject.equals(Object.class) || HyperLogLogCollector.class.isAssignableFrom(classOfObject)) { - return new HyperUniquesBufferAggregator(selector); - } - - throw new IAE("Incompatible type for metric[%s], expected a HyperUnique, got a %s", fieldName, classOfObject); + 
validateInputs(metricFactory.getColumnCapabilities(fieldName)); + return new HyperUniquesBufferAggregator(selector); } @Override public VectorAggregator factorizeVector(final VectorColumnSelectorFactory selectorFactory) { - final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); - if (!Types.is(capabilities, ValueType.COMPLEX)) { + final ColumnCapabilities columnCapabilities = selectorFactory.getColumnCapabilities(fieldName); + if (!Types.is(columnCapabilities, ValueType.COMPLEX)) { return NoopVectorAggregator.instance(); } else { + validateInputs(columnCapabilities); return new HyperUniquesVectorAggregator(selectorFactory.makeObjectSelector(fieldName)); } } + /** + * Validates whether the aggregator supports the input column type. + * Supported column types are complex types of hyperUnique, preComputedHyperUnique, as well as UNKNOWN_COMPLEX. + * @param capabilities capabilities of the input column, or null if unknown; null capabilities are not validated + */ + private void validateInputs(@Nullable ColumnCapabilities capabilities) + { + if (capabilities != null) { + final ColumnType type = capabilities.toColumnType(); + if (!(ColumnType.UNKNOWN_COMPLEX.equals(type) || TYPE.equals(type) || PRECOMPUTED_TYPE.equals(type))) { + throw DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.UNSUPPORTED) + .build( + "Using aggregator [%s] is not supported for complex columns with type [%s].", + getIntermediateType().getComplexTypeName(), + type + ); + } + } + } + @Override public boolean canVectorize(ColumnInspector columnInspector) { diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactoryTest.java index 421a457999d..c4df70a88de 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/hyperloglog/HyperUniquesAggregatorFactoryTest.java @@ -22,20 +22,39 @@ package org.apache.druid.query.aggregation.hyperloglog; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.hash.HashFunction; import com.google.common.hash.Hashing; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.error.DruidException; import org.apache.druid.hll.HyperLogLogCollector; import org.apache.druid.hll.VersionZeroHyperLogLogCollector; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.NoopAggregator; +import org.apache.druid.query.aggregation.NoopBufferAggregator; +import org.apache.druid.query.aggregation.NoopVectorAggregator; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.TestColumnSelectorFactory; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import java.nio.ByteBuffer; import java.util.Comparator; import java.util.Random; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class HyperUniquesAggregatorFactoryTest { + static {
NullHandling.initializeForTests(); + } + static final HyperUniquesAggregatorFactory AGGREGATOR_FACTORY = new HyperUniquesAggregatorFactory( "hyperUnique", "uniques" @@ -44,6 +63,19 @@ public class HyperUniquesAggregatorFactoryTest private final HashFunction fn = Hashing.murmur3_128(); + private ColumnSelectorFactory metricFactory; + private VectorColumnSelectorFactory vectorFactory; + + @Before + public void setup() + { + final ColumnCapabilitiesImpl columnCapabilities = ColumnCapabilitiesImpl.createDefault().setType(ColumnType.NESTED_DATA); + metricFactory = new TestColumnSelectorFactory() + .addCapabilities("uniques", columnCapabilities) + .addColumnSelector("uniques", null); + vectorFactory = new TestVectorColumnSelectorFactory().addCapabilities("uniques", columnCapabilities); + } + @Test public void testDeserializeV0() { @@ -216,4 +248,39 @@ public class HyperUniquesAggregatorFactoryTest Assert.assertEquals(factory, factory2); } + + @Test + public void testFactorizeOnPrimitiveColumnType() + { + final ColumnCapabilitiesImpl columnCapabilities = ColumnCapabilitiesImpl.createDefault().setType(ColumnType.LONG); + final ColumnSelectorFactory metricFactory = new TestColumnSelectorFactory() + .addCapabilities("uniques", columnCapabilities) + .addColumnSelector("uniques", NilColumnValueSelector.instance()); + final VectorColumnSelectorFactory vectorFactory = new TestVectorColumnSelectorFactory().addCapabilities("uniques", columnCapabilities); + + Assert.assertEquals(NoopAggregator.instance(), AGGREGATOR_FACTORY.factorize(metricFactory)); + Assert.assertEquals(NoopBufferAggregator.instance(), AGGREGATOR_FACTORY.factorizeBuffered(metricFactory)); + Assert.assertEquals(NoopVectorAggregator.instance(), AGGREGATOR_FACTORY.factorizeVector(vectorFactory)); + } + + @Test + public void testFactorizeOnUnsupportedComplexColumn() + { + Throwable exception = assertThrows(DruidException.class, () -> AGGREGATOR_FACTORY.factorize(metricFactory)); + Assert.assertEquals("Using aggregator [hyperUnique] is not supported for complex columns with type [COMPLEX].", exception.getMessage()); + } + + @Test + public void testFactorizeBufferedOnUnsupportedComplexColumn() + { + Throwable exception = assertThrows(DruidException.class, () -> AGGREGATOR_FACTORY.factorizeBuffered(metricFactory)); + Assert.assertEquals("Using aggregator [hyperUnique] is not supported for complex columns with type [COMPLEX].", exception.getMessage()); + } + + @Test + public void testFactorizeVectorOnUnsupportedComplexColumn() + { + Throwable exception = assertThrows(DruidException.class, () -> AGGREGATOR_FACTORY.factorizeVector(vectorFactory)); + Assert.assertEquals("Using aggregator [hyperUnique] is not supported for complex columns with type [COMPLEX].", exception.getMessage()); + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java index eceb4ebbf80..a8bae969863 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java @@ -21,10 +21,7 @@ package org.apache.druid.sql.calcite.aggregation; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.InferTypes; -import 
org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.Optionality; @@ -44,20 +41,20 @@ import java.util.List; */ public class ApproxCountDistinctSqlAggregator implements SqlAggregator { - private static final SqlAggFunction FUNCTION_INSTANCE = new ApproxCountDistinctSqlAggFunction(); private static final String NAME = "APPROX_COUNT_DISTINCT"; - + private final SqlAggFunction delegateFunction; private final SqlAggregator delegate; public ApproxCountDistinctSqlAggregator(final SqlAggregator delegate) { this.delegate = delegate; + this.delegateFunction = new ApproxCountDistinctSqlAggFunction(delegate.calciteFunction()); } @Override public SqlAggFunction calciteFunction() { - return FUNCTION_INSTANCE; + return delegateFunction; } @Nullable @@ -85,16 +82,16 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator private static class ApproxCountDistinctSqlAggFunction extends SqlAggFunction { - ApproxCountDistinctSqlAggFunction() + ApproxCountDistinctSqlAggFunction(SqlAggFunction delegate) { super( NAME, null, SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.BIGINT), - InferTypes.VARCHAR_1024, - OperandTypes.ANY, - SqlFunctionCategory.STRING, + delegate.getOperandTypeInference(), + delegate.getOperandTypeChecker(), + delegate.getFunctionType(), false, false, Optionality.FORBIDDEN diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java index 699c7a8d1c6..c756aa64cc3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java @@ -46,13 +46,16 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; +import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; import java.util.Collections; import java.util.List; +import java.util.Objects; public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator { @@ -94,7 +97,7 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator if (arg.isDirectColumnAccess() && inputAccessor.getInputRowSignature() .getColumnType(arg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) + .map(this::isValidComplexInputType) .orElse(false)) { aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true); } else { @@ -118,6 +121,15 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator } if (inputType.is(ValueType.COMPLEX)) { + if (!isValidComplexInputType(inputType)) { + plannerContext.setPlanningError( + "Using APPROX_COUNT_DISTINCT() or enabling approximation with COUNT(DISTINCT) is not supported for" + + " column type [%s]. 
You can disable approximation by setting [%s: false] in the query context.", + arg.getDruidType(), + PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT + ); + return null; + } aggregatorFactory = new HyperUniquesAggregatorFactory( aggregatorName, dimensionSpec.getOutputName(), @@ -151,7 +163,11 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.BIGINT), InferTypes.VARCHAR_1024, - OperandTypes.ANY, + OperandTypes.or( + OperandTypes.STRING, + OperandTypes.NUMERIC, + RowSignatures.complexTypeChecker(HyperUniquesAggregatorFactory.TYPE) + ), SqlFunctionCategory.STRING, false, false, @@ -159,4 +175,10 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator ); } } + + private boolean isValidComplexInputType(ColumnType columnType) + { + return Objects.equals(columnType.getComplexTypeName(), HyperUniquesAggregatorFactory.TYPE.getComplexTypeName()) || + Objects.equals(columnType.getComplexTypeName(), HyperUniquesAggregatorFactory.PRECOMPUTED_TYPE.getComplexTypeName()); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java index a5030450501..9e6e5da2bb0 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java @@ -80,6 +80,7 @@ import org.apache.druid.sql.calcite.util.TestDataBuilder; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; import org.hamcrest.CoreMatchers; +import org.junit.Assert; import org.junit.internal.matchers.ThrowableMessageMatcher; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -7297,6 +7298,33 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testApproxCountDistinctOnUnsupportedComplexColumn() + { + assertQueryIsUnplannable( + "SELECT COUNT(DISTINCT nester) FROM druid.nested", + "Query could not be planned. A possible reason is [Using APPROX_COUNT_DISTINCT() or enabling " + + "approximation with COUNT(DISTINCT) is not supported for column type [COMPLEX]. " + + "You can disable approximation by setting [useApproximateCountDistinct: false] in the query context." + ); + } + + @Test + public void testApproxCountDistinctFunctionOnUnsupportedComplexColumn() + { + DruidException druidException = Assert.assertThrows( + DruidException.class, + () -> testQuery( + "SELECT APPROX_COUNT_DISTINCT(nester) FROM druid.nested", + ImmutableList.of(), + ImmutableList.of() + ) + ); + Assert.assertTrue(druidException.getMessage().contains( + "Cannot apply 'APPROX_COUNT_DISTINCT' to arguments of type 'APPROX_COUNT_DISTINCT(>)'" + )); + } + @Test public void testNvlJsonValueDoubleSometimesMissingEqualityFilter() {