diff --git a/codestyle/druid-forbidden-apis.txt b/codestyle/druid-forbidden-apis.txt index a99654f1212..8da588f9d31 100644 --- a/codestyle/druid-forbidden-apis.txt +++ b/codestyle/druid-forbidden-apis.txt @@ -44,6 +44,7 @@ java.util.LinkedList @ Use ArrayList or ArrayDeque instead java.util.Random#() @ Use ThreadLocalRandom.current() or the constructor with a seed (the latter in tests only!) java.lang.Math#random() @ Use ThreadLocalRandom.current() java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly +org.apache.calcite.sql.type.OperandTypes#LITERAL @ LITERAL type checker throws when literals with CAST are passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead. org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory() org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java index ebb6c7f4b14..cb9b95423a6 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java @@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory; @@ -37,6 +37,7 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -133,8 +134,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE_WITH_COMPRESSION = "'" + NAME + "(column, compression)'"; - TDigestGenerateSketchSqlAggFunction() { super( @@ -143,16 +142,19 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.OTHER), null, - OperandTypes.or( - OperandTypes.ANY, - OperandTypes.and( - OperandTypes.sequence(SIGNATURE_WITH_COMPRESSION, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ) - ), + // Validation for signatures like 'TDIGEST_GENERATE_SKETCH(column)' and + // 'TDIGEST_GENERATE_SKETCH(column, compression)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "compression") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .requiredOperandCount(1) + .literalOperands(1) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java index ee63444f6d7..7cc0bdfe8c5 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java @@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; @@ -40,6 +40,7 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -158,9 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator private static class TDigestSketchQuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column, quantile)'"; - private static final String SIGNATURE2 = "'" + NAME + "(column, quantile, compression)'"; - TDigestSketchQuantileSqlAggFunction() { super( @@ -169,19 +167,18 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.or( - OperandTypes.and( - OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ), - OperandTypes.and( - OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC) - ) - ), + // Accounts for both 'TDIGEST_QUANTILE(column, quantile)' and 'TDIGEST_QUANTILE(column, quantile, compression)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "quantile", "compression") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC) + .literalOperands(1, 2) + .requiredOperandCount(2) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java b/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java index c9dba876276..6193b82cbdd 100644 --- a/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java +++ b/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java @@ -138,6 +138,76 @@ public class TDigestSketchSqlAggregatorTest extends BaseCalciteQueryTest ); } + @Test + public void testCastedQuantileAndCompressionParamForTDigestQuantileAgg() + { + cannotVectorize(); + testQuery( + "SELECT\n" + + "TDIGEST_QUANTILE(m1, CAST(0.0 AS DOUBLE)), " + + "TDIGEST_QUANTILE(m1, CAST(0.5 AS FLOAT), CAST(200 AS INTEGER)), " + + "TDIGEST_QUANTILE(m1, CAST(1.0 AS DOUBLE), 300)\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators(ImmutableList.of( + new TDigestSketchAggregatorFactory("a0:agg", "m1", + TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION + ), + new TDigestSketchAggregatorFactory("a1:agg", "m1", + 200 + ), + new TDigestSketchAggregatorFactory("a2:agg", "m1", + 300 + ) + )) + .postAggregators( + new TDigestSketchToQuantilePostAggregator("a0", makeFieldAccessPostAgg("a0:agg"), 0.0f), + new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.5f), + new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 1.0f) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ResultMatchMode.EQUALS_EPS, + ImmutableList.of( + new Object[]{1.0, 3.5, 6.0} + ) + ); + } + + @Test + public void testComputingSketchOnNumericValuesWithCastedCompressionParameter() + { + cannotVectorize(); + + testQuery( + "SELECT\n" + + "TDIGEST_GENERATE_SKETCH(m1, CAST(200 AS INTEGER))" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators(ImmutableList.of( + new TDigestSketchAggregatorFactory("a0:agg", "m1", 200) + )) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ResultMatchMode.EQUALS_EPS, + ImmutableList.of( + new String[]{ + "\"AAAAAT/wAAAAAAAAQBgAAAAAAABAaQAAAAAAAAAAAAY/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQAgAAAAAAAA/8AAAAAAAAEAQAAAAAAAAP/AAAAAAAABAFAAAAAAAAD/wAAAAAAAAQBgAAAAAAAA=\"" + } + ) + ); + } + @Test public void testComputingSketchOnCastedString() { diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java index a83d937f680..337586b1f94 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java @@ -40,7 +40,6 @@ import org.apache.calcite.util.Static; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.PostAggregatorVisitor; @@ -143,8 +142,8 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand); // Verify that 'operand' is a literal number. - if (!SqlUtil.isLiteral(operand)) { - return BasicOperandTypeChecker.throwOrReturn( + if (!SqlUtil.isLiteral(operand, true)) { + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, cb -> cb.getValidator() @@ -156,7 +155,7 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO } if (!SqlTypeFamily.NUMERIC.contains(operandType)) { - return BasicOperandTypeChecker.throwOrReturn( + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, SqlCallBinding::newValidationSignatureError diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java index f85225d107d..b39d3441b65 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java @@ -495,6 +495,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest + " DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt + 123), 0.5) + 1000,\n" + " ABS(DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt), 0.5)),\n" + " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), 0.5, 0.8),\n" + + " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), CAST(0.5 AS DOUBLE), CAST(0.8 AS DOUBLE)),\n" + " DS_HISTOGRAM(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n" + " DS_RANK(DS_QUANTILES_SKETCH(cnt), 3),\n" + " DS_CDF(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n" @@ -588,41 +589,49 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest ), new double[]{0.5d, 0.8d} ), - new DoublesSketchToHistogramPostAggregator( + new DoublesSketchToQuantilesPostAggregator( "p13", new FieldAccessPostAggregator( "p12", "a2:agg" ), + new double[]{0.5d, 0.8d} + ), + new DoublesSketchToHistogramPostAggregator( + "p15", + new FieldAccessPostAggregator( + "p14", + "a2:agg" + ), new double[]{0.2d, 0.6d}, null ), new DoublesSketchToRankPostAggregator( - "p15", - new FieldAccessPostAggregator( - "p14", - "a2:agg" - ), - 3.0d - ), - new DoublesSketchToCDFPostAggregator( "p17", new FieldAccessPostAggregator( "p16", "a2:agg" ), - new double[]{0.2d, 0.6d} + 3.0d ), - new DoublesSketchToStringPostAggregator( + new DoublesSketchToCDFPostAggregator( "p19", new FieldAccessPostAggregator( "p18", "a2:agg" + ), + new double[]{0.2d, 0.6d} + ), + new DoublesSketchToStringPostAggregator( + "p21", + new FieldAccessPostAggregator( + "p20", + "a2:agg" ) ), new ExpressionPostAggregator( - "p20", - "replace(replace(\"p19\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch')," + "p22", + "replace(replace(\"p21\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch')," + "'Combined Buffer Capacity : 6'," + "'Combined Buffer Capacity : 8')", null, @@ -640,6 +649,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest 1124.0d, 1.0d, "[1.0,1.0]", + "[1.0,1.0]", "[0.0,0.0,6.0]", 1.0d, "[0.0,0.0,1.0]", diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java index 6a1ca49067e..cb73d94ef7e 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java @@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory; @@ -38,6 +38,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; @@ -168,8 +169,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator private static class BloomFilterSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column, maxNumEntries)'"; - BloomFilterSqlAggFunction() { super( @@ -178,13 +177,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.OTHER), null, - OperandTypes.and( - OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ), + // Allow signatures like 'BLOOM_FILTER(column, maxNumEntries)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "maxNumEntries") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .literalOperands(1) + .requiredOperandCount(2) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java index 1c77f0986e1..eff2039ffa8 100644 --- a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java +++ b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java @@ -125,7 +125,8 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest testQuery( "SELECT\n" - + "BLOOM_FILTER(dim1, 1000)\n" + + "BLOOM_FILTER(dim1, 1000),\n" + + "BLOOM_FILTER(dim1, CAST(1000 AS INTEGER))\n" + "FROM numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -145,7 +146,10 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{queryFramework().queryJsonMapper().writeValueAsString(expected1)} + new Object[]{ + queryFramework().queryJsonMapper().writeValueAsString(expected1), + queryFramework().queryJsonMapper().writeValueAsString(expected1) + } ) ); } diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java index fdc61796c4d..eceaa7b8ad2 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java @@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram; @@ -39,6 +39,7 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -221,15 +222,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator private static class FixedBucketsHistogramQuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = - "'" - + NAME - + "(column, probability, numBuckets, lowerLimit, upperLimit)'"; - private static final String SIGNATURE2 = - "'" - + NAME - + "(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)'"; - FixedBucketsHistogramQuantileSqlAggFunction() { super( @@ -238,47 +230,33 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.or( - OperandTypes.and( - OperandTypes.sequence( - SIGNATURE1, - OperandTypes.ANY, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL - ), - OperandTypes.family( - SqlTypeFamily.ANY, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC - ) - ), - OperandTypes.and( - OperandTypes.sequence( - SIGNATURE2, - OperandTypes.ANY, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL - ), - OperandTypes.family( - SqlTypeFamily.ANY, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.STRING - ) + // Allows signatures like 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit)' + // and 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)' + DefaultOperandTypeChecker + .builder() + .operandNames( + "column", + "probability", + "numBuckets", + "lowerLimit", + "upperLimit", + "outlierHandlingMode" ) - ), + .operandTypes( + SqlTypeFamily.ANY, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.STRING + ) + .literalOperands(1, 2, 3, 4, 5) + .requiredOperandCount(5) + .build(), SqlFunctionCategory.NUMERIC, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java index a3fe8dc5458..41df080147b 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java @@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.histogram.ApproximateHistogram; @@ -41,6 +41,7 @@ import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -196,9 +197,6 @@ public class QuantileSqlAggregator implements SqlAggregator private static class QuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'"; - private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'"; - QuantileSqlAggFunction() { super( @@ -207,19 +205,19 @@ public class QuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.or( - OperandTypes.and( - OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ), - OperandTypes.and( - OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC) - ) - ), + // Checks for signatures like 'APPROX_QUANTILE(column, probability)' and + // 'APPROX_QUANTILE(column, probability, resolution)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "probability", "resolution") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC) + .requiredOperandCount(2) + .literalOperands(1, 2) + .build(), SqlFunctionCategory.NUMERIC, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java index d0a12ab2f2d..c54231c86d8 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java @@ -128,6 +128,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ 6.494999885559082, 5.497499942779541, 6.499499797821045, + 6.499499797821045, + 6.499499797821045, + 6.499499797821045, 1.25 } ); @@ -142,6 +145,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.99, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n" + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 <> 'abc'),\n" + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'ignore') FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'clip') FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'overflow') FILTER(WHERE dim1 = 'abc'),\n" + "APPROX_QUANTILE_FIXED_BUCKETS(cnt, 0.5, 20, 0.0, 10.0)\n" + "FROM foo", ImmutableList.of( @@ -200,8 +206,32 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ ), not(equality("dim1", "abc", ColumnType.STRING)) ), + new FilteredAggregatorFactory( + new FixedBucketsHistogramAggregatorFactory( + "a9:agg", + "m1", + 20, + 0.0d, + 10.0d, + FixedBucketsHistogram.OutlierHandlingMode.CLIP, + false + ), + equality("dim1", "abc", ColumnType.STRING) + ), + new FilteredAggregatorFactory( + new FixedBucketsHistogramAggregatorFactory( + "a10:agg", + "m1", + 20, + 0.0d, + 10.0d, + FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW, + false + ), + equality("dim1", "abc", ColumnType.STRING) + ), new FixedBucketsHistogramAggregatorFactory( - "a8:agg", + "a11:agg", "cnt", 20, 0.0d, @@ -219,7 +249,55 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ new QuantilePostAggregator("a5", "a5:agg", 0.99f), new QuantilePostAggregator("a6", "a6:agg", 0.999f), new QuantilePostAggregator("a7", "a5:agg", 0.999f), - new QuantilePostAggregator("a8", "a8:agg", 0.50f) + new QuantilePostAggregator("a8", "a5:agg", 0.999f), + new QuantilePostAggregator("a9", "a9:agg", 0.999f), + new QuantilePostAggregator("a10", "a10:agg", 0.999f), + new QuantilePostAggregator("a11", "a11:agg", 0.50f) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + expectedResults + ); + } + + @Test + public void testQuantileWithCastedLiteralArguments() + { + final List expectedResults = ImmutableList.of(new Object[]{6.499499797821045}); + testQuery( + "SELECT\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(" + + "m1, " + + "CAST(0.999 AS DOUBLE), " + + "CAST(20 AS INTEGER), " + + "CAST(0.0 AS DOUBLE), " + + "CAST(10.0 AS DOUBLE), " + + "CAST('overflow' AS VARCHAR)" + + ") " + + "FILTER(WHERE dim1 = 'abc')\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators(ImmutableList.of( + new FilteredAggregatorFactory( + new FixedBucketsHistogramAggregatorFactory( + "a0:agg", + "m1", + 20, + 0.0d, + 10.0d, + FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW, + false + ), + equality("dim1", "abc", ColumnType.STRING) + ) + )) + .postAggregators( + new QuantilePostAggregator("a0", "a0:agg", 0.999f) ) .context(QUERY_CONTEXT_DEFAULT) .build() diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java index 9e36def53cb..6bcd26d23c1 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java @@ -121,6 +121,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest "SELECT\n" + "APPROX_QUANTILE(m1, 0.01),\n" + "APPROX_QUANTILE(m1, 0.5, 50),\n" + + "APPROX_QUANTILE(m1, CAST(0.5 AS DOUBLE), CAST(50 AS INTEGER)),\n" + "APPROX_QUANTILE(m1, 0.98, 200),\n" + "APPROX_QUANTILE(m1, 0.99),\n" + "APPROX_QUANTILE(m1 * 2, 0.97),\n" @@ -144,28 +145,29 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest ) .aggregators(ImmutableList.of( new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null, false), - new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null, false), - new ApproximateHistogramAggregatorFactory("a4:agg", "v0", null, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a3:agg", "m1", 200, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a5:agg", "v0", null, null, null, null, false), new FilteredAggregatorFactory( - new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false), equality("dim1", "abc", ColumnType.STRING) ), new FilteredAggregatorFactory( - new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a7:agg", "m1", null, null, null, null, false), not(equality("dim1", "abc", ColumnType.STRING)) ), - new ApproximateHistogramAggregatorFactory("a8:agg", "cnt", null, null, null, null, false) + new ApproximateHistogramAggregatorFactory("a9:agg", "cnt", null, null, null, null, false) )) .postAggregators( new QuantilePostAggregator("a0", "a0:agg", 0.01f), new QuantilePostAggregator("a1", "a0:agg", 0.50f), - new QuantilePostAggregator("a2", "a2:agg", 0.98f), - new QuantilePostAggregator("a3", "a0:agg", 0.99f), - new QuantilePostAggregator("a4", "a4:agg", 0.97f), - new QuantilePostAggregator("a5", "a5:agg", 0.99f), - new QuantilePostAggregator("a6", "a6:agg", 0.999f), - new QuantilePostAggregator("a7", "a5:agg", 0.999f), - new QuantilePostAggregator("a8", "a8:agg", 0.50f) + new QuantilePostAggregator("a2", "a0:agg", 0.50f), + new QuantilePostAggregator("a3", "a3:agg", 0.98f), + new QuantilePostAggregator("a4", "a0:agg", 0.99f), + new QuantilePostAggregator("a5", "a5:agg", 0.97f), + new QuantilePostAggregator("a6", "a6:agg", 0.99f), + new QuantilePostAggregator("a7", "a7:agg", 0.999f), + new QuantilePostAggregator("a8", "a6:agg", 0.999f), + new QuantilePostAggregator("a9", "a9:agg", 0.50f) ) .context(QUERY_CONTEXT_DEFAULT) .build() @@ -174,6 +176,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest new Object[]{ 1.0, 3.0, + 3.0, 5.880000114440918, 5.940000057220459, 11.640000343322754, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java index 2e031616027..efa3a9e7e32 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java @@ -33,8 +33,8 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.util.SqlVisitor; @@ -61,6 +61,7 @@ import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; @@ -369,14 +370,13 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE, InferTypes.RETURN_TYPE, - OperandTypes.or( - OperandTypes.ANY, - OperandTypes.sequence( - "'" + aggregatorType.name() + "(expr, maxBytesPerString)'", - OperandTypes.ANY, - OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL) - ) - ), + DefaultOperandTypeChecker + .builder() + .operandNames("expr", "maxBytesPerString") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .requiredOperandCount(1) + .literalOperands(1) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, false, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java index c12be459cf5..03e23503a81 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java @@ -26,7 +26,6 @@ import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.util.Optionality; @@ -38,6 +37,7 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregat import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; @@ -176,19 +176,13 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE, InferTypes.RETURN_TYPE, - OperandTypes.or( - OperandTypes.sequence( - "'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn)'", - OperandTypes.ANY, - OperandTypes.family(SqlTypeFamily.TIMESTAMP) - ), - OperandTypes.sequence( - "'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn, maxBytesPerString)'", - OperandTypes.ANY, - OperandTypes.family(SqlTypeFamily.TIMESTAMP), - OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL) - ) - ), + DefaultOperandTypeChecker + .builder() + .operandNames("expr", "timeColumn", "maxBytesPerString") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.NUMERIC) + .requiredOperandCount(2) + .literalOperands(2) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, false, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/BasicOperandTypeChecker.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java similarity index 62% rename from sql/src/main/java/org/apache/druid/sql/calcite/expression/BasicOperandTypeChecker.java rename to sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java index 177d4447f4e..f43fde3a935 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/BasicOperandTypeChecker.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java @@ -24,7 +24,6 @@ import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSets; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.runtime.CalciteException; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; @@ -35,35 +34,47 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.Static; +import org.apache.druid.java.util.common.ISE; import javax.annotation.Nullable; import java.util.Arrays; +import java.util.Collections; import java.util.List; -import java.util.function.Function; import java.util.stream.IntStream; /** - * Operand type checker that is used in simple situations: there are a particular number of operands, with + * Operand type checker that is used in 'simple' situations: there are a particular number of operands, with * particular types, some of which may be optional or nullable, and some of which may be required to be literals. */ -public class BasicOperandTypeChecker implements SqlOperandTypeChecker +public class DefaultOperandTypeChecker implements SqlOperandTypeChecker { + /** + * Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the + * {@link #operandTypes} are used instead. + */ + private final List operandNames; private final List operandTypes; private final int requiredOperands; - private final IntSet nullOperands; + private final IntSet nullableOperands; private final IntSet literalOperands; - BasicOperandTypeChecker( + private DefaultOperandTypeChecker( + final List operandNames, final List operandTypes, final int requiredOperands, - final IntSet nullOperands, + final IntSet nullableOperands, @Nullable final int[] literalOperands ) { Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0); + this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames"); this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes"); this.requiredOperands = requiredOperands; - this.nullOperands = Preconditions.checkNotNull(nullOperands, "nullOperands"); + this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands"); + + if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) { + throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size()); + } if (literalOperands == null) { this.literalOperands = IntSets.EMPTY_SET; @@ -78,19 +89,6 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker return new Builder(); } - public static boolean throwOrReturn( - final boolean throwOnFailure, - final SqlCallBinding callBinding, - final Function exceptionMapper - ) - { - if (throwOnFailure) { - throw exceptionMapper.apply(callBinding); - } else { - return false; - } - } - @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { @@ -98,9 +96,9 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker final SqlNode operand = callBinding.operands().get(i); if (literalOperands.contains(i)) { - // Verify that 'operand' is a literal. - if (!SqlUtil.isLiteral(operand)) { - return throwOrReturn( + // Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later. + if (!SqlUtil.isLiteral(operand, true)) { + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, cb -> cb.getValidator() @@ -121,15 +119,15 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker // Operand came in with one of the expected types. } else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) { // Null came in, check if operand is a nullable type. - if (!nullOperands.contains(i)) { - return throwOrReturn( + if (!nullableOperands.contains(i)) { + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal()) ); } } else { - return throwOrReturn( + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, SqlCallBinding::newValidationSignatureError @@ -149,7 +147,25 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker @Override public String getAllowedSignatures(SqlOperator op, String opName) { - return SqlUtil.getAliasedSignature(op, opName, operandTypes); + final List operands = !operandNames.isEmpty() ? operandNames : operandTypes; + final StringBuilder ret = new StringBuilder(); + ret.append("'"); + ret.append(opName); + ret.append("("); + for (int i = 0; i < operands.size(); i++) { + if (i > 0) { + ret.append(", "); + } + if (i >= requiredOperands) { + ret.append("["); + } + ret.append("<").append(operands.get(i)).append(">"); + } + for (int i = requiredOperands; i < operands.size(); i++) { + ret.append("]"); + } + ret.append(")'"); + return ret.toString(); } @Override @@ -166,64 +182,70 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker public static class Builder { + private List operandNames = Collections.emptyList(); private List operandTypes; - private Integer requiredOperandCount = null; - private int[] literalOperands = null; - /** - * Signifies that a function accepts operands of type family given by {@param operandTypes}. - */ + @Nullable + private Integer requiredOperandCount; + private int[] literalOperands; + + private Builder() + { + } + + public Builder operandNames(final String... operandNames) + { + this.operandNames = Arrays.asList(operandNames); + return this; + } + + public Builder operandNames(final List operandNames) + { + this.operandNames = operandNames; + return this; + } + public Builder operandTypes(final SqlTypeFamily... operandTypes) { this.operandTypes = Arrays.asList(operandTypes); return this; } - /** - * Signifies that a function accepts operands of type family given by {@param operandTypes}. - */ public Builder operandTypes(final List operandTypes) { this.operandTypes = operandTypes; return this; } - /** - * Signifies that the first {@code requiredOperands} operands are required, and all later operands are optional. - * - * Required operands are not allowed to be null. Optional operands can either be skipped or explicitly provided as - * literal NULLs. For example, if {@code requiredOperands == 1}, then {@code F(x, NULL)} and {@code F(x)} are both - * accepted, and {@code x} must not be null. - */ - public Builder requiredOperandCount(final int requiredOperandCount) + public Builder requiredOperandCount(Integer requiredOperandCount) { this.requiredOperandCount = requiredOperandCount; return this; } - /** - * Signifies that the operands at positions given by {@code literalOperands} must be literals. - */ public Builder literalOperands(final int... literalOperands) { this.literalOperands = literalOperands; return this; } - public BasicOperandTypeChecker build() + public DefaultOperandTypeChecker build() { - // Create "nullableOperands" set including all optional arguments. - final IntSet nullableOperands = new IntArraySet(); - if (requiredOperandCount != null) { - IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add); - } - - return new BasicOperandTypeChecker( + int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount; + return new DefaultOperandTypeChecker( + operandNames, operandTypes, - requiredOperandCount == null ? operandTypes.size() : requiredOperandCount, - nullableOperands, + computedRequiredOperandCount, + DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()), literalOperands ); } } + + public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount) + { + final IntSet nullableOperands = new IntArraySet(); + IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add); + return nullableOperands; + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java index e8a2b796a26..39137b77aeb 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java @@ -19,13 +19,11 @@ package org.apache.druid.sql.calcite.expression; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntSet; -import it.unimi.dsi.fastutil.ints.IntSets; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; @@ -37,13 +35,9 @@ import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.type.BasicSqlType; import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; @@ -51,7 +45,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.util.Optionality; -import org.apache.calcite.util.Static; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; @@ -69,7 +63,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Function; -import java.util.stream.IntStream; /** * Utilities for assisting in writing {@link SqlOperatorConversion} implementations. @@ -334,6 +327,13 @@ public class OperatorConversions return new AggregatorBuilder(name); } + /** + * Helps in creating the operator builder along with validations and type inference for simple operator conversions. + * + * The type checker for this operator conversion can be either supplied manually, or an instance of + * {@link DefaultOperandTypeChecker} will be used if the user passes in {@link #operandTypes} and other optional + * parameters. Exactly one of them must be supplied to the builder. + */ public static class OperatorBuilder { protected final String name; @@ -491,8 +491,8 @@ public class OperatorConversions } /** - * Equivalent to calling {@link BasicOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a - * {@link BasicOperandTypeChecker} as our operand type checker. + * Equivalent to calling {@link DefaultOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a + * {@link DefaultOperandTypeChecker} as our operand type checker. * * May be used in conjunction with {@link #requiredOperandCount(int)} and {@link #literalOperands(int...)} in order * to further refine operand checking logic. @@ -506,12 +506,9 @@ public class OperatorConversions } /** - * Equivalent to calling {@link BasicOperandTypeChecker.Builder#requiredOperandCount(int)}; leads to using a - * {@link BasicOperandTypeChecker} as our operand type checker. - * - * Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}. + * Equivalent to calling {@link DefaultOperandTypeChecker.Builder#requiredOperandCount(Integer)}; leads to using a + * {@link DefaultOperandTypeChecker} as our operand type checker. */ - @Deprecated public OperatorBuilder requiredOperandCount(final int requiredOperandCount) { this.requiredOperandCount = requiredOperandCount; @@ -531,8 +528,8 @@ public class OperatorConversions } /** - * Equivalent to calling {@link BasicOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a - * {@link BasicOperandTypeChecker} as our operand type checker. + * Equivalent to calling {@link DefaultOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a + * {@link DefaultOperandTypeChecker} as our operand type checker. * * Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}. */ @@ -554,37 +551,30 @@ public class OperatorConversions @SuppressWarnings("unchecked") public T build() { - final IntSet nullableOperands = buildNullableOperands(); return (T) new SqlFunction( name, kind, Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"), - buildOperandTypeInference(nullableOperands), - buildOperandTypeChecker(nullableOperands), + buildOperandTypeInference(), + buildOperandTypeChecker(), functionCategory ); } - protected IntSet buildNullableOperands() - { - // Create "nullableOperands" set including all optional arguments. - final IntSet nullableOperands = new IntArraySet(); - if (requiredOperandCount != null) { - IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add); - } - return nullableOperands; - } - - protected SqlOperandTypeChecker buildOperandTypeChecker(final IntSet nullableOperands) + protected SqlOperandTypeChecker buildOperandTypeChecker() { if (operandTypeChecker == null) { - return new DefaultOperandTypeChecker( - operandNames, - operandTypes, - requiredOperandCount == null ? operandTypes.size() : requiredOperandCount, - nullableOperands, - literalOperands - ); + if (operandTypes == null) { + throw DruidException.defensive( + "'operandTypes' must be non null if 'operandTypeChecker' is not passed to the operator conversion."); + } + return DefaultOperandTypeChecker + .builder() + .operandNames(operandNames) + .operandTypes(operandTypes) + .requiredOperandCount(requiredOperandCount == null ? operandTypes.size() : requiredOperandCount) + .literalOperands(literalOperands) + .build(); } else if (operandNames.isEmpty() && operandTypes == null && requiredOperandCount == null @@ -598,8 +588,11 @@ public class OperatorConversions } } - protected SqlOperandTypeInference buildOperandTypeInference(final IntSet nullableOperands) + protected SqlOperandTypeInference buildOperandTypeInference() { + final IntSet nullableOperands = requiredOperandCount == null + ? new IntArraySet() + : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size()); if (operandTypeInference == null) { SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands); return (callBinding, returnType, types) -> { @@ -634,9 +627,8 @@ public class OperatorConversions @Override public SqlAggFunction build() { - final IntSet nullableOperands = buildNullableOperands(); - final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands); - final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(nullableOperands); + final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(); + final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(); class DruidSqlAggFunction extends SqlAggFunction { @@ -735,147 +727,6 @@ public class OperatorConversions } } - /** - * Operand type checker that is used in 'simple' situations: there are a particular number of operands, with - * particular types, some of which may be optional or nullable, and some of which may be required to be literals. - */ - @VisibleForTesting - static class DefaultOperandTypeChecker implements SqlOperandTypeChecker - { - /** - * Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the - * {@link #operandTypes} are used instead. - */ - private final List operandNames; - private final List operandTypes; - private final int requiredOperands; - private final IntSet nullableOperands; - private final IntSet literalOperands; - - public int getNumberOfLiteralOperands() - { - return literalOperands.size(); - } - - @VisibleForTesting - DefaultOperandTypeChecker( - final List operandNames, - final List operandTypes, - final int requiredOperands, - final IntSet nullableOperands, - @Nullable final int[] literalOperands - ) - { - Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0); - this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames"); - this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes"); - this.requiredOperands = requiredOperands; - this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands"); - - if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) { - throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size()); - } - - if (literalOperands == null) { - this.literalOperands = IntSets.EMPTY_SET; - } else { - this.literalOperands = new IntArraySet(); - Arrays.stream(literalOperands).forEach(this.literalOperands::add); - } - } - - @Override - public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) - { - for (int i = 0; i < callBinding.operands().size(); i++) { - final SqlNode operand = callBinding.operands().get(i); - - if (literalOperands.contains(i)) { - // Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later. - if (!SqlUtil.isLiteral(operand, true)) { - return throwOrReturn( - throwOnFailure, - callBinding, - cb -> cb.getValidator() - .newValidationError( - operand, - Static.RESOURCE.argumentMustBeLiteral(callBinding.getOperator().getName()) - ) - ); - } - } - - final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand); - final SqlTypeFamily expectedFamily = operandTypes.get(i); - - if (expectedFamily == SqlTypeFamily.ANY) { - // ANY matches anything. This operand is all good; do nothing. - } else if (expectedFamily.getTypeNames().contains(operandType.getSqlTypeName())) { - // Operand came in with one of the expected types. - } else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) { - // Null came in, check if operand is a nullable type. - if (!nullableOperands.contains(i)) { - return throwOrReturn( - throwOnFailure, - callBinding, - cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal()) - ); - } - } else { - return throwOrReturn( - throwOnFailure, - callBinding, - SqlCallBinding::newValidationSignatureError - ); - } - } - - return true; - } - - @Override - public SqlOperandCountRange getOperandCountRange() - { - return SqlOperandCountRanges.between(requiredOperands, operandTypes.size()); - } - - @Override - public String getAllowedSignatures(SqlOperator op, String opName) - { - final List operands = !operandNames.isEmpty() ? operandNames : operandTypes; - final StringBuilder ret = new StringBuilder(); - ret.append("'"); - ret.append(opName); - ret.append("("); - for (int i = 0; i < operands.size(); i++) { - if (i > 0) { - ret.append(", "); - } - if (i >= requiredOperands) { - ret.append("["); - } - ret.append("<").append(operands.get(i)).append(">"); - } - for (int i = requiredOperands; i < operands.size(); i++) { - ret.append("]"); - } - ret.append(")'"); - return ret.toString(); - } - - @Override - public Consistency getConsistency() - { - return Consistency.NONE; - } - - @Override - public boolean isOptional(int i) - { - return i + 1 > requiredOperands; - } - } - public static boolean throwOrReturn( final boolean throwOnFailure, final SqlCallBinding callBinding, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java index 94b90ed9af8..51bf88c363a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java @@ -23,7 +23,6 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.java.util.common.StringUtils; @@ -52,13 +51,10 @@ public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConvers private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME)) - .operandTypeChecker( - OperandTypes.sequence( - "'" + StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME) + "(typeName, base64)'", - OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL), - OperandTypes.ANY - ) - ) + .operandNames("typeName", "base64") + .operandTypes(SqlTypeFamily.STRING, SqlTypeFamily.ANY) + .requiredOperandCount(2) + .literalOperands(0) .returnTypeInference(ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java index ffaa42b1b65..2ba07c126af 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java @@ -118,13 +118,10 @@ public class NestedDataOperatorConversions { private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("JSON_KEYS") - .operandTypeChecker( - OperandTypes.sequence( - "'JSON_KEYS(expr, path)'", - OperandTypes.ANY, - OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL) - ) - ) + .operandNames("expr", "path") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.STRING) + .literalOperands(1) + .requiredOperandCount(2) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) .returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java index 21d4bea356c..4f266e7a0e2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java @@ -31,7 +31,6 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; @@ -43,16 +42,10 @@ public class QueryLookupOperatorConversion implements SqlOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("LOOKUP") - .operandTypeChecker( - BasicOperandTypeChecker.builder() - .operandTypes( - SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER - ) - .requiredOperandCount(2) - .literalOperands(2) - .build()) + .operandNames("expr", "lookupName", "replaceMissingValueWith") + .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) + .requiredOperandCount(2) + .literalOperands(2) .returnTypeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java index b99043e5fcb..e3dcf71879a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java @@ -27,7 +27,6 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexContext; @@ -53,13 +52,10 @@ public class TimeInIntervalConvertletFactory implements DruidConvertletFactory private static final SqlOperator OPERATOR = OperatorConversions .operatorBuilder(NAME) - .operandTypeChecker( - OperandTypes.sequence( - "'" + NAME + "(, )'", - OperandTypes.family(SqlTypeFamily.TIMESTAMP), - OperandTypes.and(OperandTypes.family(SqlTypeFamily.CHARACTER), OperandTypes.LITERAL) - ) - ) + .operandNames("timestamp", "interval") + .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) + .requiredOperandCount(2) + .literalOperands(1) .returnTypeNonNull(SqlTypeName.BOOLEAN) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java index 315d6da4fd9..f65c98bc1a1 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java @@ -4454,6 +4454,48 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testGroupByCastedRootKeysJsonPath() + { + cannotVectorize(); + testQuery( + "SELECT " + + "JSON_KEYS(nester, CAST('$.' AS VARCHAR)), " + + "SUM(cnt) " + + "FROM druid.nested GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + new ExpressionVirtualColumn( + "v0", + "json_keys(\"nester\",'$.')", + ColumnType.STRING_ARRAY, + queryFramework().macroTable() + ) + ) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "d0", ColumnType.STRING_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{null, 5L}, + new Object[]{"[\"array\",\"n\"]", 2L} + ), + RowSignature.builder() + .add("EXPR$0", ColumnType.STRING_ARRAY) + .add("EXPR$1", ColumnType.LONG) + .build() + ); + } + @Test public void testGroupByAllPaths() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index b67db5dce41..702d73d9f9d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -642,10 +642,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest testQuery( "SELECT " - + "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), " + + "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), EARLIEST(dim1, CAST(10 AS INTEGER)), " + "EARLIEST(cnt + 1), EARLIEST(m1 + 1), EARLIEST(dim1 || CAST(cnt AS VARCHAR), 10), " + "EARLIEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), " - + "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) " + + "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) " + "FROM druid.numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -677,7 +677,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{1L, 1.0f, "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21"} + new Object[]{1L, 1.0f, "", "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21", "21"} ) ); } @@ -752,10 +752,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest testQuery( "SELECT " - + "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), " + + "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), LATEST(dim1, CAST(10 AS INTEGER)), " + "LATEST(cnt + 1), LATEST(m1 + 1), LATEST(dim1 || CAST(cnt AS VARCHAR), 10), " + "LATEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), " - + "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) " + + "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), CAST(10 AS INTEGER)) " + "FROM druid.numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -787,7 +787,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{1L, 6.0f, "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11"} + new Object[]{1L, 6.0f, "abc", "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11", "10.11"} ) ); } @@ -5612,6 +5612,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testCountStarWithTimeInCastedIntervalFilter() + { + testQuery( + "SELECT COUNT(*) FROM druid.foo " + + "WHERE TIME_IN_INTERVAL(__time, CAST('2000-01-01/P1Y' AS VARCHAR)) " + + "AND TIME_IN_INTERVAL(CURRENT_TIMESTAMP, '2000/3000') -- Optimized away: always true", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.of("2000-01-01/2001-01-01"))) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{3L} + ) + ); + } + @Test public void testCountStarWithTimeInIntervalFilterLosAngeles() { @@ -5661,7 +5683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest expected -> { expected.expect(CoreMatchers.instanceOf(DruidException.class)); expected.expect(ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString( - "Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [38])"))); + "Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [63])"))); } ); } @@ -14076,6 +14098,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } } + @Test public void testComplexDecodeAgg() { @@ -14109,6 +14132,43 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) ); } + + @Test + public void testComplexDecodeAggWithCastedTypeName() + { + msqIncompatible(); + cannotVectorize(); + testQuery( + "SELECT " + + "APPROX_COUNT_DISTINCT_BUILTIN(COMPLEX_DECODE_BASE64(CAST('hyperUnique' AS VARCHAR),PARSE_JSON(TO_JSON_STRING(unique_dim1)))) " + + "FROM druid.foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn( + "v0", + "complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))", + ColumnType.ofComplex("hyperUnique") + ) + ) + .aggregators( + new HyperUniquesAggregatorFactory( + "a0", + "v0", + false, + true + ) + ) + .build() + ), + ImmutableList.of( + new Object[]{6L} + ) + ); + } + @NotYetSupported @Test public void testOrderByAlongWithInternalScanQuery() diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java index a240a2c2986..5a0c52514f3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.expression; import com.google.common.collect.ImmutableList; -import it.unimi.dsi.fastutil.ints.IntSets; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.runtime.Resources.ExInst; @@ -38,7 +37,6 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.junit.Assert; import org.junit.Rule; @@ -51,7 +49,6 @@ import org.mockito.Mockito; import org.mockito.stubbing.Answer; import java.util.ArrayList; -import java.util.Collections; import java.util.List; @RunWith(Enclosed.class) @@ -65,13 +62,13 @@ public class OperatorConversionsTest @Test public void testGetOperandCountRange() { - SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker( - Collections.emptyList(), - ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), - 2, - IntSets.EMPTY_SET, - null - ); + SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker + .builder() + .operandNames() + .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) + .requiredOperandCount(2) + .literalOperands() + .build(); SqlOperandCountRange countRange = typeChecker.getOperandCountRange(); Assert.assertEquals(2, countRange.getMin()); Assert.assertEquals(3, countRange.getMax()); @@ -80,13 +77,13 @@ public class OperatorConversionsTest @Test public void testIsOptional() { - SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker( - Collections.emptyList(), - ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), - 2, - IntSets.EMPTY_SET, - null - ); + SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker + .builder() + .operandNames() + .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) + .requiredOperandCount(2) + .literalOperands() + .build(); Assert.assertFalse(typeChecker.isOptional(0)); Assert.assertFalse(typeChecker.isOptional(1)); Assert.assertTrue(typeChecker.isOptional(2)); @@ -121,7 +118,7 @@ public class OperatorConversionsTest { SqlFunction function = OperatorConversions .operatorBuilder("testRequiredOperandsOnly") - .operandTypeChecker(BasicOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build()) + .operandTypeChecker(DefaultOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build()) .returnTypeNonNull(SqlTypeName.CHAR) .build(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();