From 2ea7177f15139071980a881f111f7dbd86287baa Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Wed, 1 Nov 2023 10:38:48 +0530 Subject: [PATCH] Allow casted literal values in SQL functions accepting literals (#15282) Functions that accept literals also allow casted literals. This shouldn't have an impact on the queries that the user writes. It enables the SQL functions to accept explicit cast, which is required with JDBC. --- codestyle/druid-forbidden-apis.txt | 1 + .../TDigestGenerateSketchSqlAggregator.java | 24 +- .../TDigestSketchQuantileSqlAggregator.java | 27 +-- .../sql/TDigestSketchSqlAggregatorTest.java | 70 ++++++ ...esSketchListArgBaseOperatorConversion.java | 7 +- .../sql/DoublesSketchSqlAggregatorTest.java | 36 +-- .../bloom/sql/BloomFilterSqlAggregator.java | 20 +- .../sql/BloomFilterSqlAggregatorTest.java | 8 +- ...BucketsHistogramQuantileSqlAggregator.java | 74 +++--- .../histogram/sql/QuantileSqlAggregator.java | 28 ++- ...etsHistogramQuantileSqlAggregatorTest.java | 82 ++++++- .../sql/QuantileSqlAggregatorTest.java | 27 ++- .../EarliestLatestAnySqlAggregator.java | 18 +- .../EarliestLatestBySqlAggregator.java | 22 +- ...er.java => DefaultOperandTypeChecker.java} | 136 ++++++----- .../expression/OperatorConversions.java | 217 +++--------------- ...ComplexDecodeBase64OperatorConversion.java | 12 +- .../NestedDataOperatorConversions.java | 11 +- .../QueryLookupOperatorConversion.java | 15 +- .../TimeInIntervalConvertletFactory.java | 12 +- .../calcite/CalciteNestedDataQueryTest.java | 42 ++++ .../druid/sql/calcite/CalciteQueryTest.java | 74 +++++- .../expression/OperatorConversionsTest.java | 33 ++- 23 files changed, 544 insertions(+), 452 deletions(-) rename sql/src/main/java/org/apache/druid/sql/calcite/expression/{BasicOperandTypeChecker.java => DefaultOperandTypeChecker.java} (62%) diff --git a/codestyle/druid-forbidden-apis.txt b/codestyle/druid-forbidden-apis.txt index a99654f1212..8da588f9d31 100644 --- a/codestyle/druid-forbidden-apis.txt +++ b/codestyle/druid-forbidden-apis.txt @@ -44,6 +44,7 @@ java.util.LinkedList @ Use ArrayList or ArrayDeque instead java.util.Random#() @ Use ThreadLocalRandom.current() or the constructor with a seed (the latter in tests only!) java.lang.Math#random() @ Use ThreadLocalRandom.current() java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly +org.apache.calcite.sql.type.OperandTypes#LITERAL @ LITERAL type checker throws when literals with CAST are passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead. org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory() org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java index ebb6c7f4b14..cb9b95423a6 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java @@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory; @@ -37,6 +37,7 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -133,8 +134,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE_WITH_COMPRESSION = "'" + NAME + "(column, compression)'"; - TDigestGenerateSketchSqlAggFunction() { super( @@ -143,16 +142,19 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.OTHER), null, - OperandTypes.or( - OperandTypes.ANY, - OperandTypes.and( - OperandTypes.sequence(SIGNATURE_WITH_COMPRESSION, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ) - ), + // Validation for signatures like 'TDIGEST_GENERATE_SKETCH(column)' and + // 'TDIGEST_GENERATE_SKETCH(column, compression)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "compression") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .requiredOperandCount(1) + .literalOperands(1) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java index ee63444f6d7..7cc0bdfe8c5 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java @@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; @@ -40,6 +40,7 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -158,9 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator private static class TDigestSketchQuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column, quantile)'"; - private static final String SIGNATURE2 = "'" + NAME + "(column, quantile, compression)'"; - TDigestSketchQuantileSqlAggFunction() { super( @@ -169,19 +167,18 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.or( - OperandTypes.and( - OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ), - OperandTypes.and( - OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC) - ) - ), + // Accounts for both 'TDIGEST_QUANTILE(column, quantile)' and 'TDIGEST_QUANTILE(column, quantile, compression)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "quantile", "compression") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC) + .literalOperands(1, 2) + .requiredOperandCount(2) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java b/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java index c9dba876276..6193b82cbdd 100644 --- a/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java +++ b/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java @@ -138,6 +138,76 @@ public class TDigestSketchSqlAggregatorTest extends BaseCalciteQueryTest ); } + @Test + public void testCastedQuantileAndCompressionParamForTDigestQuantileAgg() + { + cannotVectorize(); + testQuery( + "SELECT\n" + + "TDIGEST_QUANTILE(m1, CAST(0.0 AS DOUBLE)), " + + "TDIGEST_QUANTILE(m1, CAST(0.5 AS FLOAT), CAST(200 AS INTEGER)), " + + "TDIGEST_QUANTILE(m1, CAST(1.0 AS DOUBLE), 300)\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators(ImmutableList.of( + new TDigestSketchAggregatorFactory("a0:agg", "m1", + TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION + ), + new TDigestSketchAggregatorFactory("a1:agg", "m1", + 200 + ), + new TDigestSketchAggregatorFactory("a2:agg", "m1", + 300 + ) + )) + .postAggregators( + new TDigestSketchToQuantilePostAggregator("a0", makeFieldAccessPostAgg("a0:agg"), 0.0f), + new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.5f), + new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 1.0f) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ResultMatchMode.EQUALS_EPS, + ImmutableList.of( + new Object[]{1.0, 3.5, 6.0} + ) + ); + } + + @Test + public void testComputingSketchOnNumericValuesWithCastedCompressionParameter() + { + cannotVectorize(); + + testQuery( + "SELECT\n" + + "TDIGEST_GENERATE_SKETCH(m1, CAST(200 AS INTEGER))" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators(ImmutableList.of( + new TDigestSketchAggregatorFactory("a0:agg", "m1", 200) + )) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ResultMatchMode.EQUALS_EPS, + ImmutableList.of( + new String[]{ + "\"AAAAAT/wAAAAAAAAQBgAAAAAAABAaQAAAAAAAAAAAAY/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQAgAAAAAAAA/8AAAAAAAAEAQAAAAAAAAP/AAAAAAAABAFAAAAAAAAD/wAAAAAAAAQBgAAAAAAAA=\"" + } + ) + ); + } + @Test public void testComputingSketchOnCastedString() { diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java index a83d937f680..337586b1f94 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchListArgBaseOperatorConversion.java @@ -40,7 +40,6 @@ import org.apache.calcite.util.Static; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.PostAggregatorVisitor; @@ -143,8 +142,8 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand); // Verify that 'operand' is a literal number. - if (!SqlUtil.isLiteral(operand)) { - return BasicOperandTypeChecker.throwOrReturn( + if (!SqlUtil.isLiteral(operand, true)) { + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, cb -> cb.getValidator() @@ -156,7 +155,7 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO } if (!SqlTypeFamily.NUMERIC.contains(operandType)) { - return BasicOperandTypeChecker.throwOrReturn( + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, SqlCallBinding::newValidationSignatureError diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java index f85225d107d..b39d3441b65 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java @@ -495,6 +495,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest + " DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt + 123), 0.5) + 1000,\n" + " ABS(DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt), 0.5)),\n" + " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), 0.5, 0.8),\n" + + " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), CAST(0.5 AS DOUBLE), CAST(0.8 AS DOUBLE)),\n" + " DS_HISTOGRAM(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n" + " DS_RANK(DS_QUANTILES_SKETCH(cnt), 3),\n" + " DS_CDF(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n" @@ -588,41 +589,49 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest ), new double[]{0.5d, 0.8d} ), - new DoublesSketchToHistogramPostAggregator( + new DoublesSketchToQuantilesPostAggregator( "p13", new FieldAccessPostAggregator( "p12", "a2:agg" ), + new double[]{0.5d, 0.8d} + ), + new DoublesSketchToHistogramPostAggregator( + "p15", + new FieldAccessPostAggregator( + "p14", + "a2:agg" + ), new double[]{0.2d, 0.6d}, null ), new DoublesSketchToRankPostAggregator( - "p15", - new FieldAccessPostAggregator( - "p14", - "a2:agg" - ), - 3.0d - ), - new DoublesSketchToCDFPostAggregator( "p17", new FieldAccessPostAggregator( "p16", "a2:agg" ), - new double[]{0.2d, 0.6d} + 3.0d ), - new DoublesSketchToStringPostAggregator( + new DoublesSketchToCDFPostAggregator( "p19", new FieldAccessPostAggregator( "p18", "a2:agg" + ), + new double[]{0.2d, 0.6d} + ), + new DoublesSketchToStringPostAggregator( + "p21", + new FieldAccessPostAggregator( + "p20", + "a2:agg" ) ), new ExpressionPostAggregator( - "p20", - "replace(replace(\"p19\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch')," + "p22", + "replace(replace(\"p21\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch')," + "'Combined Buffer Capacity : 6'," + "'Combined Buffer Capacity : 8')", null, @@ -640,6 +649,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest 1124.0d, 1.0d, "[1.0,1.0]", + "[1.0,1.0]", "[0.0,0.0,6.0]", 1.0d, "[0.0,0.0,1.0]", diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java index 6a1ca49067e..cb73d94ef7e 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java @@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory; @@ -38,6 +38,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; @@ -168,8 +169,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator private static class BloomFilterSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column, maxNumEntries)'"; - BloomFilterSqlAggFunction() { super( @@ -178,13 +177,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.OTHER), null, - OperandTypes.and( - OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ), + // Allow signatures like 'BLOOM_FILTER(column, maxNumEntries)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "maxNumEntries") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .literalOperands(1) + .requiredOperandCount(2) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java index 1c77f0986e1..eff2039ffa8 100644 --- a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java +++ b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java @@ -125,7 +125,8 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest testQuery( "SELECT\n" - + "BLOOM_FILTER(dim1, 1000)\n" + + "BLOOM_FILTER(dim1, 1000),\n" + + "BLOOM_FILTER(dim1, CAST(1000 AS INTEGER))\n" + "FROM numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -145,7 +146,10 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{queryFramework().queryJsonMapper().writeValueAsString(expected1)} + new Object[]{ + queryFramework().queryJsonMapper().writeValueAsString(expected1), + queryFramework().queryJsonMapper().writeValueAsString(expected1) + } ) ); } diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java index fdc61796c4d..eceaa7b8ad2 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java @@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram; @@ -39,6 +39,7 @@ import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -221,15 +222,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator private static class FixedBucketsHistogramQuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = - "'" - + NAME - + "(column, probability, numBuckets, lowerLimit, upperLimit)'"; - private static final String SIGNATURE2 = - "'" - + NAME - + "(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)'"; - FixedBucketsHistogramQuantileSqlAggFunction() { super( @@ -238,47 +230,33 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.or( - OperandTypes.and( - OperandTypes.sequence( - SIGNATURE1, - OperandTypes.ANY, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL - ), - OperandTypes.family( - SqlTypeFamily.ANY, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC - ) - ), - OperandTypes.and( - OperandTypes.sequence( - SIGNATURE2, - OperandTypes.ANY, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL, - OperandTypes.LITERAL - ), - OperandTypes.family( - SqlTypeFamily.ANY, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.NUMERIC, - SqlTypeFamily.STRING - ) + // Allows signatures like 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit)' + // and 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)' + DefaultOperandTypeChecker + .builder() + .operandNames( + "column", + "probability", + "numBuckets", + "lowerLimit", + "upperLimit", + "outlierHandlingMode" ) - ), + .operandTypes( + SqlTypeFamily.ANY, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.NUMERIC, + SqlTypeFamily.STRING + ) + .literalOperands(1, 2, 3, 4, 5) + .requiredOperandCount(5) + .build(), SqlFunctionCategory.NUMERIC, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java index a3fe8dc5458..41df080147b 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java @@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.histogram.ApproximateHistogram; @@ -41,6 +41,7 @@ import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -196,9 +197,6 @@ public class QuantileSqlAggregator implements SqlAggregator private static class QuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'"; - private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'"; - QuantileSqlAggFunction() { super( @@ -207,19 +205,19 @@ public class QuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.or( - OperandTypes.and( - OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) - ), - OperandTypes.and( - OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC) - ) - ), + // Checks for signatures like 'APPROX_QUANTILE(column, probability)' and + // 'APPROX_QUANTILE(column, probability, resolution)' + DefaultOperandTypeChecker + .builder() + .operandNames("column", "probability", "resolution") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC) + .requiredOperandCount(2) + .literalOperands(1, 2) + .build(), SqlFunctionCategory.NUMERIC, false, - false + false, + Optionality.FORBIDDEN ); } } diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java index d0a12ab2f2d..c54231c86d8 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java @@ -128,6 +128,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ 6.494999885559082, 5.497499942779541, 6.499499797821045, + 6.499499797821045, + 6.499499797821045, + 6.499499797821045, 1.25 } ); @@ -142,6 +145,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.99, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n" + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 <> 'abc'),\n" + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'ignore') FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'clip') FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'overflow') FILTER(WHERE dim1 = 'abc'),\n" + "APPROX_QUANTILE_FIXED_BUCKETS(cnt, 0.5, 20, 0.0, 10.0)\n" + "FROM foo", ImmutableList.of( @@ -200,8 +206,32 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ ), not(equality("dim1", "abc", ColumnType.STRING)) ), + new FilteredAggregatorFactory( + new FixedBucketsHistogramAggregatorFactory( + "a9:agg", + "m1", + 20, + 0.0d, + 10.0d, + FixedBucketsHistogram.OutlierHandlingMode.CLIP, + false + ), + equality("dim1", "abc", ColumnType.STRING) + ), + new FilteredAggregatorFactory( + new FixedBucketsHistogramAggregatorFactory( + "a10:agg", + "m1", + 20, + 0.0d, + 10.0d, + FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW, + false + ), + equality("dim1", "abc", ColumnType.STRING) + ), new FixedBucketsHistogramAggregatorFactory( - "a8:agg", + "a11:agg", "cnt", 20, 0.0d, @@ -219,7 +249,55 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ new QuantilePostAggregator("a5", "a5:agg", 0.99f), new QuantilePostAggregator("a6", "a6:agg", 0.999f), new QuantilePostAggregator("a7", "a5:agg", 0.999f), - new QuantilePostAggregator("a8", "a8:agg", 0.50f) + new QuantilePostAggregator("a8", "a5:agg", 0.999f), + new QuantilePostAggregator("a9", "a9:agg", 0.999f), + new QuantilePostAggregator("a10", "a10:agg", 0.999f), + new QuantilePostAggregator("a11", "a11:agg", 0.50f) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + expectedResults + ); + } + + @Test + public void testQuantileWithCastedLiteralArguments() + { + final List expectedResults = ImmutableList.of(new Object[]{6.499499797821045}); + testQuery( + "SELECT\n" + + "APPROX_QUANTILE_FIXED_BUCKETS(" + + "m1, " + + "CAST(0.999 AS DOUBLE), " + + "CAST(20 AS INTEGER), " + + "CAST(0.0 AS DOUBLE), " + + "CAST(10.0 AS DOUBLE), " + + "CAST('overflow' AS VARCHAR)" + + ") " + + "FILTER(WHERE dim1 = 'abc')\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(Granularities.ALL) + .aggregators(ImmutableList.of( + new FilteredAggregatorFactory( + new FixedBucketsHistogramAggregatorFactory( + "a0:agg", + "m1", + 20, + 0.0d, + 10.0d, + FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW, + false + ), + equality("dim1", "abc", ColumnType.STRING) + ) + )) + .postAggregators( + new QuantilePostAggregator("a0", "a0:agg", 0.999f) ) .context(QUERY_CONTEXT_DEFAULT) .build() diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java index 9e36def53cb..6bcd26d23c1 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java @@ -121,6 +121,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest "SELECT\n" + "APPROX_QUANTILE(m1, 0.01),\n" + "APPROX_QUANTILE(m1, 0.5, 50),\n" + + "APPROX_QUANTILE(m1, CAST(0.5 AS DOUBLE), CAST(50 AS INTEGER)),\n" + "APPROX_QUANTILE(m1, 0.98, 200),\n" + "APPROX_QUANTILE(m1, 0.99),\n" + "APPROX_QUANTILE(m1 * 2, 0.97),\n" @@ -144,28 +145,29 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest ) .aggregators(ImmutableList.of( new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null, false), - new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null, false), - new ApproximateHistogramAggregatorFactory("a4:agg", "v0", null, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a3:agg", "m1", 200, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a5:agg", "v0", null, null, null, null, false), new FilteredAggregatorFactory( - new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false), equality("dim1", "abc", ColumnType.STRING) ), new FilteredAggregatorFactory( - new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false), + new ApproximateHistogramAggregatorFactory("a7:agg", "m1", null, null, null, null, false), not(equality("dim1", "abc", ColumnType.STRING)) ), - new ApproximateHistogramAggregatorFactory("a8:agg", "cnt", null, null, null, null, false) + new ApproximateHistogramAggregatorFactory("a9:agg", "cnt", null, null, null, null, false) )) .postAggregators( new QuantilePostAggregator("a0", "a0:agg", 0.01f), new QuantilePostAggregator("a1", "a0:agg", 0.50f), - new QuantilePostAggregator("a2", "a2:agg", 0.98f), - new QuantilePostAggregator("a3", "a0:agg", 0.99f), - new QuantilePostAggregator("a4", "a4:agg", 0.97f), - new QuantilePostAggregator("a5", "a5:agg", 0.99f), - new QuantilePostAggregator("a6", "a6:agg", 0.999f), - new QuantilePostAggregator("a7", "a5:agg", 0.999f), - new QuantilePostAggregator("a8", "a8:agg", 0.50f) + new QuantilePostAggregator("a2", "a0:agg", 0.50f), + new QuantilePostAggregator("a3", "a3:agg", 0.98f), + new QuantilePostAggregator("a4", "a0:agg", 0.99f), + new QuantilePostAggregator("a5", "a5:agg", 0.97f), + new QuantilePostAggregator("a6", "a6:agg", 0.99f), + new QuantilePostAggregator("a7", "a7:agg", 0.999f), + new QuantilePostAggregator("a8", "a6:agg", 0.999f), + new QuantilePostAggregator("a9", "a9:agg", 0.50f) ) .context(QUERY_CONTEXT_DEFAULT) .build() @@ -174,6 +176,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest new Object[]{ 1.0, 3.0, + 3.0, 5.880000114440918, 5.940000057220459, 11.640000343322754, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java index 2e031616027..efa3a9e7e32 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java @@ -33,8 +33,8 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.util.SqlVisitor; @@ -61,6 +61,7 @@ import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; @@ -369,14 +370,13 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE, InferTypes.RETURN_TYPE, - OperandTypes.or( - OperandTypes.ANY, - OperandTypes.sequence( - "'" + aggregatorType.name() + "(expr, maxBytesPerString)'", - OperandTypes.ANY, - OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL) - ) - ), + DefaultOperandTypeChecker + .builder() + .operandNames("expr", "maxBytesPerString") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .requiredOperandCount(1) + .literalOperands(1) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, false, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java index c12be459cf5..03e23503a81 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java @@ -26,7 +26,6 @@ import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.util.Optionality; @@ -38,6 +37,7 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregat import org.apache.druid.segment.column.ColumnType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; @@ -176,19 +176,13 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE, InferTypes.RETURN_TYPE, - OperandTypes.or( - OperandTypes.sequence( - "'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn)'", - OperandTypes.ANY, - OperandTypes.family(SqlTypeFamily.TIMESTAMP) - ), - OperandTypes.sequence( - "'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn, maxBytesPerString)'", - OperandTypes.ANY, - OperandTypes.family(SqlTypeFamily.TIMESTAMP), - OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL) - ) - ), + DefaultOperandTypeChecker + .builder() + .operandNames("expr", "timeColumn", "maxBytesPerString") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.NUMERIC) + .requiredOperandCount(2) + .literalOperands(2) + .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, false, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/BasicOperandTypeChecker.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java similarity index 62% rename from sql/src/main/java/org/apache/druid/sql/calcite/expression/BasicOperandTypeChecker.java rename to sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java index 177d4447f4e..f43fde3a935 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/BasicOperandTypeChecker.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java @@ -24,7 +24,6 @@ import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSets; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.runtime.CalciteException; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; @@ -35,35 +34,47 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.Static; +import org.apache.druid.java.util.common.ISE; import javax.annotation.Nullable; import java.util.Arrays; +import java.util.Collections; import java.util.List; -import java.util.function.Function; import java.util.stream.IntStream; /** - * Operand type checker that is used in simple situations: there are a particular number of operands, with + * Operand type checker that is used in 'simple' situations: there are a particular number of operands, with * particular types, some of which may be optional or nullable, and some of which may be required to be literals. */ -public class BasicOperandTypeChecker implements SqlOperandTypeChecker +public class DefaultOperandTypeChecker implements SqlOperandTypeChecker { + /** + * Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the + * {@link #operandTypes} are used instead. + */ + private final List operandNames; private final List operandTypes; private final int requiredOperands; - private final IntSet nullOperands; + private final IntSet nullableOperands; private final IntSet literalOperands; - BasicOperandTypeChecker( + private DefaultOperandTypeChecker( + final List operandNames, final List operandTypes, final int requiredOperands, - final IntSet nullOperands, + final IntSet nullableOperands, @Nullable final int[] literalOperands ) { Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0); + this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames"); this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes"); this.requiredOperands = requiredOperands; - this.nullOperands = Preconditions.checkNotNull(nullOperands, "nullOperands"); + this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands"); + + if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) { + throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size()); + } if (literalOperands == null) { this.literalOperands = IntSets.EMPTY_SET; @@ -78,19 +89,6 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker return new Builder(); } - public static boolean throwOrReturn( - final boolean throwOnFailure, - final SqlCallBinding callBinding, - final Function exceptionMapper - ) - { - if (throwOnFailure) { - throw exceptionMapper.apply(callBinding); - } else { - return false; - } - } - @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { @@ -98,9 +96,9 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker final SqlNode operand = callBinding.operands().get(i); if (literalOperands.contains(i)) { - // Verify that 'operand' is a literal. - if (!SqlUtil.isLiteral(operand)) { - return throwOrReturn( + // Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later. + if (!SqlUtil.isLiteral(operand, true)) { + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, cb -> cb.getValidator() @@ -121,15 +119,15 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker // Operand came in with one of the expected types. } else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) { // Null came in, check if operand is a nullable type. - if (!nullOperands.contains(i)) { - return throwOrReturn( + if (!nullableOperands.contains(i)) { + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal()) ); } } else { - return throwOrReturn( + return OperatorConversions.throwOrReturn( throwOnFailure, callBinding, SqlCallBinding::newValidationSignatureError @@ -149,7 +147,25 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker @Override public String getAllowedSignatures(SqlOperator op, String opName) { - return SqlUtil.getAliasedSignature(op, opName, operandTypes); + final List operands = !operandNames.isEmpty() ? operandNames : operandTypes; + final StringBuilder ret = new StringBuilder(); + ret.append("'"); + ret.append(opName); + ret.append("("); + for (int i = 0; i < operands.size(); i++) { + if (i > 0) { + ret.append(", "); + } + if (i >= requiredOperands) { + ret.append("["); + } + ret.append("<").append(operands.get(i)).append(">"); + } + for (int i = requiredOperands; i < operands.size(); i++) { + ret.append("]"); + } + ret.append(")'"); + return ret.toString(); } @Override @@ -166,64 +182,70 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker public static class Builder { + private List operandNames = Collections.emptyList(); private List operandTypes; - private Integer requiredOperandCount = null; - private int[] literalOperands = null; - /** - * Signifies that a function accepts operands of type family given by {@param operandTypes}. - */ + @Nullable + private Integer requiredOperandCount; + private int[] literalOperands; + + private Builder() + { + } + + public Builder operandNames(final String... operandNames) + { + this.operandNames = Arrays.asList(operandNames); + return this; + } + + public Builder operandNames(final List operandNames) + { + this.operandNames = operandNames; + return this; + } + public Builder operandTypes(final SqlTypeFamily... operandTypes) { this.operandTypes = Arrays.asList(operandTypes); return this; } - /** - * Signifies that a function accepts operands of type family given by {@param operandTypes}. - */ public Builder operandTypes(final List operandTypes) { this.operandTypes = operandTypes; return this; } - /** - * Signifies that the first {@code requiredOperands} operands are required, and all later operands are optional. - * - * Required operands are not allowed to be null. Optional operands can either be skipped or explicitly provided as - * literal NULLs. For example, if {@code requiredOperands == 1}, then {@code F(x, NULL)} and {@code F(x)} are both - * accepted, and {@code x} must not be null. - */ - public Builder requiredOperandCount(final int requiredOperandCount) + public Builder requiredOperandCount(Integer requiredOperandCount) { this.requiredOperandCount = requiredOperandCount; return this; } - /** - * Signifies that the operands at positions given by {@code literalOperands} must be literals. - */ public Builder literalOperands(final int... literalOperands) { this.literalOperands = literalOperands; return this; } - public BasicOperandTypeChecker build() + public DefaultOperandTypeChecker build() { - // Create "nullableOperands" set including all optional arguments. - final IntSet nullableOperands = new IntArraySet(); - if (requiredOperandCount != null) { - IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add); - } - - return new BasicOperandTypeChecker( + int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount; + return new DefaultOperandTypeChecker( + operandNames, operandTypes, - requiredOperandCount == null ? operandTypes.size() : requiredOperandCount, - nullableOperands, + computedRequiredOperandCount, + DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()), literalOperands ); } } + + public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount) + { + final IntSet nullableOperands = new IntArraySet(); + IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add); + return nullableOperands; + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java index e8a2b796a26..39137b77aeb 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java @@ -19,13 +19,11 @@ package org.apache.druid.sql.calcite.expression; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntSet; -import it.unimi.dsi.fastutil.ints.IntSets; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; @@ -37,13 +35,9 @@ import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.type.BasicSqlType; import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; @@ -51,7 +45,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.util.Optionality; -import org.apache.calcite.util.Static; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; @@ -69,7 +63,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Function; -import java.util.stream.IntStream; /** * Utilities for assisting in writing {@link SqlOperatorConversion} implementations. @@ -334,6 +327,13 @@ public class OperatorConversions return new AggregatorBuilder(name); } + /** + * Helps in creating the operator builder along with validations and type inference for simple operator conversions. + * + * The type checker for this operator conversion can be either supplied manually, or an instance of + * {@link DefaultOperandTypeChecker} will be used if the user passes in {@link #operandTypes} and other optional + * parameters. Exactly one of them must be supplied to the builder. + */ public static class OperatorBuilder { protected final String name; @@ -491,8 +491,8 @@ public class OperatorConversions } /** - * Equivalent to calling {@link BasicOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a - * {@link BasicOperandTypeChecker} as our operand type checker. + * Equivalent to calling {@link DefaultOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a + * {@link DefaultOperandTypeChecker} as our operand type checker. * * May be used in conjunction with {@link #requiredOperandCount(int)} and {@link #literalOperands(int...)} in order * to further refine operand checking logic. @@ -506,12 +506,9 @@ public class OperatorConversions } /** - * Equivalent to calling {@link BasicOperandTypeChecker.Builder#requiredOperandCount(int)}; leads to using a - * {@link BasicOperandTypeChecker} as our operand type checker. - * - * Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}. + * Equivalent to calling {@link DefaultOperandTypeChecker.Builder#requiredOperandCount(Integer)}; leads to using a + * {@link DefaultOperandTypeChecker} as our operand type checker. */ - @Deprecated public OperatorBuilder requiredOperandCount(final int requiredOperandCount) { this.requiredOperandCount = requiredOperandCount; @@ -531,8 +528,8 @@ public class OperatorConversions } /** - * Equivalent to calling {@link BasicOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a - * {@link BasicOperandTypeChecker} as our operand type checker. + * Equivalent to calling {@link DefaultOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a + * {@link DefaultOperandTypeChecker} as our operand type checker. * * Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}. */ @@ -554,37 +551,30 @@ public class OperatorConversions @SuppressWarnings("unchecked") public T build() { - final IntSet nullableOperands = buildNullableOperands(); return (T) new SqlFunction( name, kind, Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"), - buildOperandTypeInference(nullableOperands), - buildOperandTypeChecker(nullableOperands), + buildOperandTypeInference(), + buildOperandTypeChecker(), functionCategory ); } - protected IntSet buildNullableOperands() - { - // Create "nullableOperands" set including all optional arguments. - final IntSet nullableOperands = new IntArraySet(); - if (requiredOperandCount != null) { - IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add); - } - return nullableOperands; - } - - protected SqlOperandTypeChecker buildOperandTypeChecker(final IntSet nullableOperands) + protected SqlOperandTypeChecker buildOperandTypeChecker() { if (operandTypeChecker == null) { - return new DefaultOperandTypeChecker( - operandNames, - operandTypes, - requiredOperandCount == null ? operandTypes.size() : requiredOperandCount, - nullableOperands, - literalOperands - ); + if (operandTypes == null) { + throw DruidException.defensive( + "'operandTypes' must be non null if 'operandTypeChecker' is not passed to the operator conversion."); + } + return DefaultOperandTypeChecker + .builder() + .operandNames(operandNames) + .operandTypes(operandTypes) + .requiredOperandCount(requiredOperandCount == null ? operandTypes.size() : requiredOperandCount) + .literalOperands(literalOperands) + .build(); } else if (operandNames.isEmpty() && operandTypes == null && requiredOperandCount == null @@ -598,8 +588,11 @@ public class OperatorConversions } } - protected SqlOperandTypeInference buildOperandTypeInference(final IntSet nullableOperands) + protected SqlOperandTypeInference buildOperandTypeInference() { + final IntSet nullableOperands = requiredOperandCount == null + ? new IntArraySet() + : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size()); if (operandTypeInference == null) { SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands); return (callBinding, returnType, types) -> { @@ -634,9 +627,8 @@ public class OperatorConversions @Override public SqlAggFunction build() { - final IntSet nullableOperands = buildNullableOperands(); - final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands); - final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(nullableOperands); + final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(); + final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(); class DruidSqlAggFunction extends SqlAggFunction { @@ -735,147 +727,6 @@ public class OperatorConversions } } - /** - * Operand type checker that is used in 'simple' situations: there are a particular number of operands, with - * particular types, some of which may be optional or nullable, and some of which may be required to be literals. - */ - @VisibleForTesting - static class DefaultOperandTypeChecker implements SqlOperandTypeChecker - { - /** - * Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the - * {@link #operandTypes} are used instead. - */ - private final List operandNames; - private final List operandTypes; - private final int requiredOperands; - private final IntSet nullableOperands; - private final IntSet literalOperands; - - public int getNumberOfLiteralOperands() - { - return literalOperands.size(); - } - - @VisibleForTesting - DefaultOperandTypeChecker( - final List operandNames, - final List operandTypes, - final int requiredOperands, - final IntSet nullableOperands, - @Nullable final int[] literalOperands - ) - { - Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0); - this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames"); - this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes"); - this.requiredOperands = requiredOperands; - this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands"); - - if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) { - throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size()); - } - - if (literalOperands == null) { - this.literalOperands = IntSets.EMPTY_SET; - } else { - this.literalOperands = new IntArraySet(); - Arrays.stream(literalOperands).forEach(this.literalOperands::add); - } - } - - @Override - public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) - { - for (int i = 0; i < callBinding.operands().size(); i++) { - final SqlNode operand = callBinding.operands().get(i); - - if (literalOperands.contains(i)) { - // Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later. - if (!SqlUtil.isLiteral(operand, true)) { - return throwOrReturn( - throwOnFailure, - callBinding, - cb -> cb.getValidator() - .newValidationError( - operand, - Static.RESOURCE.argumentMustBeLiteral(callBinding.getOperator().getName()) - ) - ); - } - } - - final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand); - final SqlTypeFamily expectedFamily = operandTypes.get(i); - - if (expectedFamily == SqlTypeFamily.ANY) { - // ANY matches anything. This operand is all good; do nothing. - } else if (expectedFamily.getTypeNames().contains(operandType.getSqlTypeName())) { - // Operand came in with one of the expected types. - } else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) { - // Null came in, check if operand is a nullable type. - if (!nullableOperands.contains(i)) { - return throwOrReturn( - throwOnFailure, - callBinding, - cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal()) - ); - } - } else { - return throwOrReturn( - throwOnFailure, - callBinding, - SqlCallBinding::newValidationSignatureError - ); - } - } - - return true; - } - - @Override - public SqlOperandCountRange getOperandCountRange() - { - return SqlOperandCountRanges.between(requiredOperands, operandTypes.size()); - } - - @Override - public String getAllowedSignatures(SqlOperator op, String opName) - { - final List operands = !operandNames.isEmpty() ? operandNames : operandTypes; - final StringBuilder ret = new StringBuilder(); - ret.append("'"); - ret.append(opName); - ret.append("("); - for (int i = 0; i < operands.size(); i++) { - if (i > 0) { - ret.append(", "); - } - if (i >= requiredOperands) { - ret.append("["); - } - ret.append("<").append(operands.get(i)).append(">"); - } - for (int i = requiredOperands; i < operands.size(); i++) { - ret.append("]"); - } - ret.append(")'"); - return ret.toString(); - } - - @Override - public Consistency getConsistency() - { - return Consistency.NONE; - } - - @Override - public boolean isOptional(int i) - { - return i + 1 > requiredOperands; - } - } - public static boolean throwOrReturn( final boolean throwOnFailure, final SqlCallBinding callBinding, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java index 94b90ed9af8..51bf88c363a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ComplexDecodeBase64OperatorConversion.java @@ -23,7 +23,6 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.java.util.common.StringUtils; @@ -52,13 +51,10 @@ public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConvers private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME)) - .operandTypeChecker( - OperandTypes.sequence( - "'" + StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME) + "(typeName, base64)'", - OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL), - OperandTypes.ANY - ) - ) + .operandNames("typeName", "base64") + .operandTypes(SqlTypeFamily.STRING, SqlTypeFamily.ANY) + .requiredOperandCount(2) + .literalOperands(0) .returnTypeInference(ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java index ffaa42b1b65..2ba07c126af 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/NestedDataOperatorConversions.java @@ -118,13 +118,10 @@ public class NestedDataOperatorConversions { private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("JSON_KEYS") - .operandTypeChecker( - OperandTypes.sequence( - "'JSON_KEYS(expr, path)'", - OperandTypes.ANY, - OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL) - ) - ) + .operandNames("expr", "path") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.STRING) + .literalOperands(1) + .requiredOperandCount(2) .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION) .returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java index 21d4bea356c..4f266e7a0e2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/QueryLookupOperatorConversion.java @@ -31,7 +31,6 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; @@ -43,16 +42,10 @@ public class QueryLookupOperatorConversion implements SqlOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("LOOKUP") - .operandTypeChecker( - BasicOperandTypeChecker.builder() - .operandTypes( - SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER - ) - .requiredOperandCount(2) - .literalOperands(2) - .build()) + .operandNames("expr", "lookupName", "replaceMissingValueWith") + .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) + .requiredOperandCount(2) + .literalOperands(2) .returnTypeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java index b99043e5fcb..e3dcf71879a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/convertlet/TimeInIntervalConvertletFactory.java @@ -27,7 +27,6 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexContext; @@ -53,13 +52,10 @@ public class TimeInIntervalConvertletFactory implements DruidConvertletFactory private static final SqlOperator OPERATOR = OperatorConversions .operatorBuilder(NAME) - .operandTypeChecker( - OperandTypes.sequence( - "'" + NAME + "(, )'", - OperandTypes.family(SqlTypeFamily.TIMESTAMP), - OperandTypes.and(OperandTypes.family(SqlTypeFamily.CHARACTER), OperandTypes.LITERAL) - ) - ) + .operandNames("timestamp", "interval") + .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) + .requiredOperandCount(2) + .literalOperands(1) .returnTypeNonNull(SqlTypeName.BOOLEAN) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java index 315d6da4fd9..f65c98bc1a1 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java @@ -4454,6 +4454,48 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testGroupByCastedRootKeysJsonPath() + { + cannotVectorize(); + testQuery( + "SELECT " + + "JSON_KEYS(nester, CAST('$.' AS VARCHAR)), " + + "SUM(cnt) " + + "FROM druid.nested GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + new ExpressionVirtualColumn( + "v0", + "json_keys(\"nester\",'$.')", + ColumnType.STRING_ARRAY, + queryFramework().macroTable() + ) + ) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "d0", ColumnType.STRING_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{null, 5L}, + new Object[]{"[\"array\",\"n\"]", 2L} + ), + RowSignature.builder() + .add("EXPR$0", ColumnType.STRING_ARRAY) + .add("EXPR$1", ColumnType.LONG) + .build() + ); + } + @Test public void testGroupByAllPaths() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index b67db5dce41..702d73d9f9d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -642,10 +642,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest testQuery( "SELECT " - + "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), " + + "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), EARLIEST(dim1, CAST(10 AS INTEGER)), " + "EARLIEST(cnt + 1), EARLIEST(m1 + 1), EARLIEST(dim1 || CAST(cnt AS VARCHAR), 10), " + "EARLIEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), " - + "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) " + + "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) " + "FROM druid.numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -677,7 +677,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{1L, 1.0f, "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21"} + new Object[]{1L, 1.0f, "", "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21", "21"} ) ); } @@ -752,10 +752,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest testQuery( "SELECT " - + "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), " + + "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), LATEST(dim1, CAST(10 AS INTEGER)), " + "LATEST(cnt + 1), LATEST(m1 + 1), LATEST(dim1 || CAST(cnt AS VARCHAR), 10), " + "LATEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), " - + "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) " + + "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), CAST(10 AS INTEGER)) " + "FROM druid.numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -787,7 +787,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{1L, 6.0f, "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11"} + new Object[]{1L, 6.0f, "abc", "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11", "10.11"} ) ); } @@ -5612,6 +5612,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testCountStarWithTimeInCastedIntervalFilter() + { + testQuery( + "SELECT COUNT(*) FROM druid.foo " + + "WHERE TIME_IN_INTERVAL(__time, CAST('2000-01-01/P1Y' AS VARCHAR)) " + + "AND TIME_IN_INTERVAL(CURRENT_TIMESTAMP, '2000/3000') -- Optimized away: always true", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.of("2000-01-01/2001-01-01"))) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{3L} + ) + ); + } + @Test public void testCountStarWithTimeInIntervalFilterLosAngeles() { @@ -5661,7 +5683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest expected -> { expected.expect(CoreMatchers.instanceOf(DruidException.class)); expected.expect(ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString( - "Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [38])"))); + "Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [63])"))); } ); } @@ -14076,6 +14098,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } } + @Test public void testComplexDecodeAgg() { @@ -14109,6 +14132,43 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) ); } + + @Test + public void testComplexDecodeAggWithCastedTypeName() + { + msqIncompatible(); + cannotVectorize(); + testQuery( + "SELECT " + + "APPROX_COUNT_DISTINCT_BUILTIN(COMPLEX_DECODE_BASE64(CAST('hyperUnique' AS VARCHAR),PARSE_JSON(TO_JSON_STRING(unique_dim1)))) " + + "FROM druid.foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn( + "v0", + "complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))", + ColumnType.ofComplex("hyperUnique") + ) + ) + .aggregators( + new HyperUniquesAggregatorFactory( + "a0", + "v0", + false, + true + ) + ) + .build() + ), + ImmutableList.of( + new Object[]{6L} + ) + ); + } + @NotYetSupported @Test public void testOrderByAlongWithInternalScanQuery() diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java index a240a2c2986..5a0c52514f3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.expression; import com.google.common.collect.ImmutableList; -import it.unimi.dsi.fastutil.ints.IntSets; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.runtime.Resources.ExInst; @@ -38,7 +37,6 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.junit.Assert; import org.junit.Rule; @@ -51,7 +49,6 @@ import org.mockito.Mockito; import org.mockito.stubbing.Answer; import java.util.ArrayList; -import java.util.Collections; import java.util.List; @RunWith(Enclosed.class) @@ -65,13 +62,13 @@ public class OperatorConversionsTest @Test public void testGetOperandCountRange() { - SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker( - Collections.emptyList(), - ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), - 2, - IntSets.EMPTY_SET, - null - ); + SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker + .builder() + .operandNames() + .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) + .requiredOperandCount(2) + .literalOperands() + .build(); SqlOperandCountRange countRange = typeChecker.getOperandCountRange(); Assert.assertEquals(2, countRange.getMin()); Assert.assertEquals(3, countRange.getMax()); @@ -80,13 +77,13 @@ public class OperatorConversionsTest @Test public void testIsOptional() { - SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker( - Collections.emptyList(), - ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), - 2, - IntSets.EMPTY_SET, - null - ); + SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker + .builder() + .operandNames() + .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) + .requiredOperandCount(2) + .literalOperands() + .build(); Assert.assertFalse(typeChecker.isOptional(0)); Assert.assertFalse(typeChecker.isOptional(1)); Assert.assertTrue(typeChecker.isOptional(2)); @@ -121,7 +118,7 @@ public class OperatorConversionsTest { SqlFunction function = OperatorConversions .operatorBuilder("testRequiredOperandsOnly") - .operandTypeChecker(BasicOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build()) + .operandTypeChecker(DefaultOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build()) .returnTypeNonNull(SqlTypeName.CHAR) .build(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();