Allow casted literal values in SQL functions accepting literals (#15282)

Functions that accept literals now also allow casted literals. This shouldn't have any impact on the queries that users already write; it enables these SQL functions to accept explicit casts, which are required when querying via JDBC.
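A minimal sketch (not part of this commit) of the Calcite behavior involved: the strict single-argument SqlUtil.isLiteral() rejects a casted literal such as CAST(0.5 AS DOUBLE), while the two-argument overload with allowCast = true, which the changed checkers now rely on, accepts it.

import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;

// Sketch: how Calcite classifies plain vs. casted literals.
public class CastedLiteralDemo
{
  public static void main(String[] args) throws SqlParseException
  {
    final SqlNode plain = SqlParser.create("0.5").parseExpression();
    final SqlNode casted = SqlParser.create("CAST(0.5 AS DOUBLE)").parseExpression();

    System.out.println(SqlUtil.isLiteral(plain));        // true
    System.out.println(SqlUtil.isLiteral(casted));       // false: the strict check, analogous to OperandTypes.LITERAL
    System.out.println(SqlUtil.isLiteral(casted, true)); // true: allowCast unwraps the CAST around the literal
  }
}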
Laksh Singla 2023-11-01 10:38:48 +05:30 committed by GitHub
parent 49e0cba7ba
commit 2ea7177f15
23 changed files with 544 additions and 452 deletions

View File

@ -44,6 +44,7 @@ java.util.LinkedList @ Use ArrayList or ArrayDeque instead
java.util.Random#<init>() @ Use ThreadLocalRandom.current() or the constructor with a seed (the latter in tests only!)
java.lang.Math#random() @ Use ThreadLocalRandom.current()
java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly
org.apache.calcite.sql.type.OperandTypes#LITERAL @ The LITERAL type checker throws when a literal wrapped in a CAST is passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead.
org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead
org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory()
org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead
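For context, a sketch of the migration this rule enforces, using the operand names from the TDigest diff below: a sequence checker built on OperandTypes.LITERAL rejects CAST(200 AS INTEGER), whereas the DefaultOperandTypeChecker builder marks literal positions with literalOperands() and its literal check allows casts.

// Forbidden pattern: fails on TDIGEST_GENERATE_SKETCH(m1, CAST(200 AS INTEGER))
//   OperandTypes.sequence(signature, OperandTypes.ANY, OperandTypes.LITERAL)

// Replacement pattern used throughout this commit:
DefaultOperandTypeChecker
    .builder()
    .operandNames("column", "compression")
    .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
    .requiredOperandCount(1)
    .literalOperands(1) // operand at index 1 must be a literal; a CAST around a literal passes
    .build();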

View File

@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory;
@ -37,6 +37,7 @@ import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@ -133,8 +134,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE_WITH_COMPRESSION = "'" + NAME + "(column, compression)'";
TDigestGenerateSketchSqlAggFunction()
{
super(
@ -143,16 +142,19 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE_WITH_COMPRESSION, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
// Validation for signatures like 'TDIGEST_GENERATE_SKETCH(column)' and
// 'TDIGEST_GENERATE_SKETCH(column, compression)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "compression")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.requiredOperandCount(1)
.literalOperands(1)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
false,
Optionality.FORBIDDEN
);
}
}

View File

@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
@ -40,6 +40,7 @@ import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@ -158,9 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
private static class TDigestSketchQuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, quantile)'";
private static final String SIGNATURE2 = "'" + NAME + "(column, quantile, compression)'";
TDigestSketchQuantileSqlAggFunction()
{
super(
@ -169,19 +167,18 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
)
),
// Accounts for both 'TDIGEST_QUANTILE(column, quantile)' and 'TDIGEST_QUANTILE(column, quantile, compression)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "quantile", "compression")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
.literalOperands(1, 2)
.requiredOperandCount(2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
false,
Optionality.FORBIDDEN
);
}
}

View File

@ -138,6 +138,76 @@ public class TDigestSketchSqlAggregatorTest extends BaseCalciteQueryTest
);
}
@Test
public void testCastedQuantileAndCompressionParamForTDigestQuantileAgg()
{
cannotVectorize();
testQuery(
"SELECT\n"
+ "TDIGEST_QUANTILE(m1, CAST(0.0 AS DOUBLE)), "
+ "TDIGEST_QUANTILE(m1, CAST(0.5 AS FLOAT), CAST(200 AS INTEGER)), "
+ "TDIGEST_QUANTILE(m1, CAST(1.0 AS DOUBLE), 300)\n"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "m1",
TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION
),
new TDigestSketchAggregatorFactory("a1:agg", "m1",
200
),
new TDigestSketchAggregatorFactory("a2:agg", "m1",
300
)
))
.postAggregators(
new TDigestSketchToQuantilePostAggregator("a0", makeFieldAccessPostAgg("a0:agg"), 0.0f),
new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.5f),
new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 1.0f)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ResultMatchMode.EQUALS_EPS,
ImmutableList.of(
new Object[]{1.0, 3.5, 6.0}
)
);
}
@Test
public void testComputingSketchOnNumericValuesWithCastedCompressionParameter()
{
cannotVectorize();
testQuery(
"SELECT\n"
+ "TDIGEST_GENERATE_SKETCH(m1, CAST(200 AS INTEGER))"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "m1", 200)
))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ResultMatchMode.EQUALS_EPS,
ImmutableList.of(
new String[]{
"\"AAAAAT/wAAAAAAAAQBgAAAAAAABAaQAAAAAAAAAAAAY/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQAgAAAAAAAA/8AAAAAAAAEAQAAAAAAAAP/AAAAAAAABAFAAAAAAAAD/wAAAAAAAAQBgAAAAAAAA=\""
}
)
);
}
@Test
public void testComputingSketchOnCastedString()
{

View File

@ -40,7 +40,6 @@ import org.apache.calcite.util.Static;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.expression.PostAggregatorVisitor;
@ -143,8 +142,8 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO
final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand);
// Verify that 'operand' is a literal number.
if (!SqlUtil.isLiteral(operand)) {
return BasicOperandTypeChecker.throwOrReturn(
if (!SqlUtil.isLiteral(operand, true)) {
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
cb -> cb.getValidator()
@ -156,7 +155,7 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO
}
if (!SqlTypeFamily.NUMERIC.contains(operandType)) {
return BasicOperandTypeChecker.throwOrReturn(
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
SqlCallBinding::newValidationSignatureError

View File

@ -495,6 +495,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
+ " DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt + 123), 0.5) + 1000,\n"
+ " ABS(DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt), 0.5)),\n"
+ " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), 0.5, 0.8),\n"
+ " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), CAST(0.5 AS DOUBLE), CAST(0.8 AS DOUBLE)),\n"
+ " DS_HISTOGRAM(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n"
+ " DS_RANK(DS_QUANTILES_SKETCH(cnt), 3),\n"
+ " DS_CDF(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n"
@ -588,41 +589,49 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
),
new double[]{0.5d, 0.8d}
),
new DoublesSketchToHistogramPostAggregator(
new DoublesSketchToQuantilesPostAggregator(
"p13",
new FieldAccessPostAggregator(
"p12",
"a2:agg"
),
new double[]{0.5d, 0.8d}
),
new DoublesSketchToHistogramPostAggregator(
"p15",
new FieldAccessPostAggregator(
"p14",
"a2:agg"
),
new double[]{0.2d, 0.6d},
null
),
new DoublesSketchToRankPostAggregator(
"p15",
new FieldAccessPostAggregator(
"p14",
"a2:agg"
),
3.0d
),
new DoublesSketchToCDFPostAggregator(
"p17",
new FieldAccessPostAggregator(
"p16",
"a2:agg"
),
new double[]{0.2d, 0.6d}
3.0d
),
new DoublesSketchToStringPostAggregator(
new DoublesSketchToCDFPostAggregator(
"p19",
new FieldAccessPostAggregator(
"p18",
"a2:agg"
),
new double[]{0.2d, 0.6d}
),
new DoublesSketchToStringPostAggregator(
"p21",
new FieldAccessPostAggregator(
"p20",
"a2:agg"
)
),
new ExpressionPostAggregator(
"p20",
"replace(replace(\"p19\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch'),"
"p22",
"replace(replace(\"p21\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch'),"
+ "'Combined Buffer Capacity : 6',"
+ "'Combined Buffer Capacity : 8')",
null,
@ -640,6 +649,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
1124.0d,
1.0d,
"[1.0,1.0]",
"[1.0,1.0]",
"[0.0,0.0,6.0]",
1.0d,
"[0.0,0.0,1.0]",

View File

@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory;
@ -38,6 +38,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
@ -168,8 +169,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
private static class BloomFilterSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, maxNumEntries)'";
BloomFilterSqlAggFunction()
{
super(
@ -178,13 +177,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
// Allow signatures like 'BLOOM_FILTER(column, maxNumEntries)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "maxNumEntries")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.literalOperands(1)
.requiredOperandCount(2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
false,
Optionality.FORBIDDEN
);
}
}

View File

@ -125,7 +125,8 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest
testQuery(
"SELECT\n"
+ "BLOOM_FILTER(dim1, 1000)\n"
+ "BLOOM_FILTER(dim1, 1000),\n"
+ "BLOOM_FILTER(dim1, CAST(1000 AS INTEGER))\n"
+ "FROM numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
@ -145,7 +146,10 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest
.build()
),
ImmutableList.of(
new Object[]{queryFramework().queryJsonMapper().writeValueAsString(expected1)}
new Object[]{
queryFramework().queryJsonMapper().writeValueAsString(expected1),
queryFramework().queryJsonMapper().writeValueAsString(expected1)
}
)
);
}

View File

@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram;
@ -39,6 +39,7 @@ import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@ -221,15 +222,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
private static class FixedBucketsHistogramQuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 =
"'"
+ NAME
+ "(column, probability, numBuckets, lowerLimit, upperLimit)'";
private static final String SIGNATURE2 =
"'"
+ NAME
+ "(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)'";
FixedBucketsHistogramQuantileSqlAggFunction()
{
super(
@ -238,47 +230,33 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(
SIGNATURE1,
OperandTypes.ANY,
OperandTypes.LITERAL,
OperandTypes.LITERAL,
OperandTypes.LITERAL,
OperandTypes.LITERAL
),
OperandTypes.family(
SqlTypeFamily.ANY,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC
)
),
OperandTypes.and(
OperandTypes.sequence(
SIGNATURE2,
OperandTypes.ANY,
OperandTypes.LITERAL,
OperandTypes.LITERAL,
OperandTypes.LITERAL,
OperandTypes.LITERAL,
OperandTypes.LITERAL
),
OperandTypes.family(
SqlTypeFamily.ANY,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.STRING
)
// Allows signatures like 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit)'
// and 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)'
DefaultOperandTypeChecker
.builder()
.operandNames(
"column",
"probability",
"numBuckets",
"lowerLimit",
"upperLimit",
"outlierHandlingMode"
)
),
.operandTypes(
SqlTypeFamily.ANY,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.NUMERIC,
SqlTypeFamily.STRING
)
.literalOperands(1, 2, 3, 4, 5)
.requiredOperandCount(5)
.build(),
SqlFunctionCategory.NUMERIC,
false,
false
false,
Optionality.FORBIDDEN
);
}
}

View File

@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.histogram.ApproximateHistogram;
@ -41,6 +41,7 @@ import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@ -196,9 +197,6 @@ public class QuantileSqlAggregator implements SqlAggregator
private static class QuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'";
private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'";
QuantileSqlAggFunction()
{
super(
@ -207,19 +205,19 @@ public class QuantileSqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
)
),
// Checks for signatures like 'APPROX_QUANTILE(column, probability)' and
// 'APPROX_QUANTILE(column, probability, resolution)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "probability", "resolution")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
.requiredOperandCount(2)
.literalOperands(1, 2)
.build(),
SqlFunctionCategory.NUMERIC,
false,
false
false,
Optionality.FORBIDDEN
);
}
}

View File

@ -128,6 +128,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
6.494999885559082,
5.497499942779541,
6.499499797821045,
6.499499797821045,
6.499499797821045,
6.499499797821045,
1.25
}
);
@ -142,6 +145,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.99, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 <> 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'ignore') FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'clip') FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'overflow') FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(cnt, 0.5, 20, 0.0, 10.0)\n"
+ "FROM foo",
ImmutableList.of(
@ -200,8 +206,32 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
),
not(equality("dim1", "abc", ColumnType.STRING))
),
new FilteredAggregatorFactory(
new FixedBucketsHistogramAggregatorFactory(
"a9:agg",
"m1",
20,
0.0d,
10.0d,
FixedBucketsHistogram.OutlierHandlingMode.CLIP,
false
),
equality("dim1", "abc", ColumnType.STRING)
),
new FilteredAggregatorFactory(
new FixedBucketsHistogramAggregatorFactory(
"a10:agg",
"m1",
20,
0.0d,
10.0d,
FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW,
false
),
equality("dim1", "abc", ColumnType.STRING)
),
new FixedBucketsHistogramAggregatorFactory(
"a8:agg",
"a11:agg",
"cnt",
20,
0.0d,
@ -219,7 +249,55 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
new QuantilePostAggregator("a5", "a5:agg", 0.99f),
new QuantilePostAggregator("a6", "a6:agg", 0.999f),
new QuantilePostAggregator("a7", "a5:agg", 0.999f),
new QuantilePostAggregator("a8", "a8:agg", 0.50f)
new QuantilePostAggregator("a8", "a5:agg", 0.999f),
new QuantilePostAggregator("a9", "a9:agg", 0.999f),
new QuantilePostAggregator("a10", "a10:agg", 0.999f),
new QuantilePostAggregator("a11", "a11:agg", 0.50f)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
expectedResults
);
}
@Test
public void testQuantileWithCastedLiteralArguments()
{
final List<Object[]> expectedResults = ImmutableList.of(new Object[]{6.499499797821045});
testQuery(
"SELECT\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS("
+ "m1, "
+ "CAST(0.999 AS DOUBLE), "
+ "CAST(20 AS INTEGER), "
+ "CAST(0.0 AS DOUBLE), "
+ "CAST(10.0 AS DOUBLE), "
+ "CAST('overflow' AS VARCHAR)"
+ ") "
+ "FILTER(WHERE dim1 = 'abc')\n"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(ImmutableList.of(
new FilteredAggregatorFactory(
new FixedBucketsHistogramAggregatorFactory(
"a0:agg",
"m1",
20,
0.0d,
10.0d,
FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW,
false
),
equality("dim1", "abc", ColumnType.STRING)
)
))
.postAggregators(
new QuantilePostAggregator("a0", "a0:agg", 0.999f)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()

View File

@ -121,6 +121,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
"SELECT\n"
+ "APPROX_QUANTILE(m1, 0.01),\n"
+ "APPROX_QUANTILE(m1, 0.5, 50),\n"
+ "APPROX_QUANTILE(m1, CAST(0.5 AS DOUBLE), CAST(50 AS INTEGER)),\n"
+ "APPROX_QUANTILE(m1, 0.98, 200),\n"
+ "APPROX_QUANTILE(m1, 0.99),\n"
+ "APPROX_QUANTILE(m1 * 2, 0.97),\n"
@ -144,28 +145,29 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
)
.aggregators(ImmutableList.of(
new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null, false),
new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null, false),
new ApproximateHistogramAggregatorFactory("a4:agg", "v0", null, null, null, null, false),
new ApproximateHistogramAggregatorFactory("a3:agg", "m1", 200, null, null, null, false),
new ApproximateHistogramAggregatorFactory("a5:agg", "v0", null, null, null, null, false),
new FilteredAggregatorFactory(
new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null, false),
new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false),
equality("dim1", "abc", ColumnType.STRING)
),
new FilteredAggregatorFactory(
new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false),
new ApproximateHistogramAggregatorFactory("a7:agg", "m1", null, null, null, null, false),
not(equality("dim1", "abc", ColumnType.STRING))
),
new ApproximateHistogramAggregatorFactory("a8:agg", "cnt", null, null, null, null, false)
new ApproximateHistogramAggregatorFactory("a9:agg", "cnt", null, null, null, null, false)
))
.postAggregators(
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
new QuantilePostAggregator("a2", "a2:agg", 0.98f),
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
new QuantilePostAggregator("a4", "a4:agg", 0.97f),
new QuantilePostAggregator("a5", "a5:agg", 0.99f),
new QuantilePostAggregator("a6", "a6:agg", 0.999f),
new QuantilePostAggregator("a7", "a5:agg", 0.999f),
new QuantilePostAggregator("a8", "a8:agg", 0.50f)
new QuantilePostAggregator("a2", "a0:agg", 0.50f),
new QuantilePostAggregator("a3", "a3:agg", 0.98f),
new QuantilePostAggregator("a4", "a0:agg", 0.99f),
new QuantilePostAggregator("a5", "a5:agg", 0.97f),
new QuantilePostAggregator("a6", "a6:agg", 0.99f),
new QuantilePostAggregator("a7", "a7:agg", 0.999f),
new QuantilePostAggregator("a8", "a6:agg", 0.999f),
new QuantilePostAggregator("a9", "a9:agg", 0.50f)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
@ -174,6 +176,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
new Object[]{
1.0,
3.0,
3.0,
5.880000114440918,
5.940000057220459,
11.640000343322754,

View File

@ -33,8 +33,8 @@ import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.util.SqlVisitor;
@ -61,6 +61,7 @@ import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
@ -369,14 +370,13 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE,
InferTypes.RETURN_TYPE,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.sequence(
"'" + aggregatorType.name() + "(expr, maxBytesPerString)'",
OperandTypes.ANY,
OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL)
)
),
DefaultOperandTypeChecker
.builder()
.operandNames("expr", "maxBytesPerString")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.requiredOperandCount(1)
.literalOperands(1)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false,

View File

@ -26,7 +26,6 @@ import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.util.Optionality;
@ -38,6 +37,7 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregat
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
@ -176,19 +176,13 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
SqlKind.OTHER_FUNCTION,
EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE,
InferTypes.RETURN_TYPE,
OperandTypes.or(
OperandTypes.sequence(
"'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn)'",
OperandTypes.ANY,
OperandTypes.family(SqlTypeFamily.TIMESTAMP)
),
OperandTypes.sequence(
"'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn, maxBytesPerString)'",
OperandTypes.ANY,
OperandTypes.family(SqlTypeFamily.TIMESTAMP),
OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL)
)
),
DefaultOperandTypeChecker
.builder()
.operandNames("expr", "timeColumn", "maxBytesPerString")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.NUMERIC)
.requiredOperandCount(2)
.literalOperands(2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false,

View File

@ -24,7 +24,6 @@ import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
@ -35,35 +34,47 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Static;
import org.apache.druid.java.util.common.ISE;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
/**
* Operand type checker that is used in simple situations: there are a particular number of operands, with
* Operand type checker that is used in 'simple' situations: there are a particular number of operands, with
* particular types, some of which may be optional or nullable, and some of which may be required to be literals.
*/
public class BasicOperandTypeChecker implements SqlOperandTypeChecker
public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
{
/**
* Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the
* {@link #operandTypes} are used instead.
*/
private final List<String> operandNames;
private final List<SqlTypeFamily> operandTypes;
private final int requiredOperands;
private final IntSet nullOperands;
private final IntSet nullableOperands;
private final IntSet literalOperands;
BasicOperandTypeChecker(
private DefaultOperandTypeChecker(
final List<String> operandNames,
final List<SqlTypeFamily> operandTypes,
final int requiredOperands,
final IntSet nullOperands,
final IntSet nullableOperands,
@Nullable final int[] literalOperands
)
{
Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0);
this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames");
this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes");
this.requiredOperands = requiredOperands;
this.nullOperands = Preconditions.checkNotNull(nullOperands, "nullOperands");
this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands");
if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) {
throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size());
}
if (literalOperands == null) {
this.literalOperands = IntSets.EMPTY_SET;
@ -78,19 +89,6 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
return new Builder();
}
public static boolean throwOrReturn(
final boolean throwOnFailure,
final SqlCallBinding callBinding,
final Function<SqlCallBinding, CalciteException> exceptionMapper
)
{
if (throwOnFailure) {
throw exceptionMapper.apply(callBinding);
} else {
return false;
}
}
@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
{
@ -98,9 +96,9 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
final SqlNode operand = callBinding.operands().get(i);
if (literalOperands.contains(i)) {
// Verify that 'operand' is a literal.
if (!SqlUtil.isLiteral(operand)) {
return throwOrReturn(
// Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later.
if (!SqlUtil.isLiteral(operand, true)) {
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
cb -> cb.getValidator()
@ -121,15 +119,15 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
// Operand came in with one of the expected types.
} else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) {
// Null came in, check if operand is a nullable type.
if (!nullOperands.contains(i)) {
return throwOrReturn(
if (!nullableOperands.contains(i)) {
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal())
);
}
} else {
return throwOrReturn(
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
SqlCallBinding::newValidationSignatureError
@ -149,7 +147,25 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
@Override
public String getAllowedSignatures(SqlOperator op, String opName)
{
return SqlUtil.getAliasedSignature(op, opName, operandTypes);
final List<?> operands = !operandNames.isEmpty() ? operandNames : operandTypes;
final StringBuilder ret = new StringBuilder();
ret.append("'");
ret.append(opName);
ret.append("(");
for (int i = 0; i < operands.size(); i++) {
if (i > 0) {
ret.append(", ");
}
if (i >= requiredOperands) {
ret.append("[");
}
ret.append("<").append(operands.get(i)).append(">");
}
for (int i = requiredOperands; i < operands.size(); i++) {
ret.append("]");
}
ret.append(")'");
return ret.toString();
}
@Override
@ -166,64 +182,70 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
public static class Builder
{
private List<String> operandNames = Collections.emptyList();
private List<SqlTypeFamily> operandTypes;
private Integer requiredOperandCount = null;
private int[] literalOperands = null;
/**
* Signifies that a function accepts operands of type family given by {@code operandTypes}.
*/
@Nullable
private Integer requiredOperandCount;
private int[] literalOperands;
private Builder()
{
}
public Builder operandNames(final String... operandNames)
{
this.operandNames = Arrays.asList(operandNames);
return this;
}
public Builder operandNames(final List<String> operandNames)
{
this.operandNames = operandNames;
return this;
}
public Builder operandTypes(final SqlTypeFamily... operandTypes)
{
this.operandTypes = Arrays.asList(operandTypes);
return this;
}
/**
* Signifies that a function accepts operands of type family given by {@code operandTypes}.
*/
public Builder operandTypes(final List<SqlTypeFamily> operandTypes)
{
this.operandTypes = operandTypes;
return this;
}
/**
* Signifies that the first {@code requiredOperands} operands are required, and all later operands are optional.
*
* Required operands are not allowed to be null. Optional operands can either be skipped or explicitly provided as
* literal NULLs. For example, if {@code requiredOperands == 1}, then {@code F(x, NULL)} and {@code F(x)} are both
* accepted, and {@code x} must not be null.
*/
public Builder requiredOperandCount(final int requiredOperandCount)
public Builder requiredOperandCount(Integer requiredOperandCount)
{
this.requiredOperandCount = requiredOperandCount;
return this;
}
/**
* Signifies that the operands at positions given by {@code literalOperands} must be literals.
*/
public Builder literalOperands(final int... literalOperands)
{
this.literalOperands = literalOperands;
return this;
}
public BasicOperandTypeChecker build()
public DefaultOperandTypeChecker build()
{
// Create "nullableOperands" set including all optional arguments.
final IntSet nullableOperands = new IntArraySet();
if (requiredOperandCount != null) {
IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add);
}
return new BasicOperandTypeChecker(
int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount;
return new DefaultOperandTypeChecker(
operandNames,
operandTypes,
requiredOperandCount == null ? operandTypes.size() : requiredOperandCount,
nullableOperands,
computedRequiredOperandCount,
DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()),
literalOperands
);
}
}
public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount)
{
final IntSet nullableOperands = new IntArraySet();
IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add);
return nullableOperands;
}
}
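As a usage note, a sketch of the signature string the new getAllowedSignatures() renders; for one required and one optional operand, the string-building loop above wraps optional operands in brackets.

// checker = DefaultOperandTypeChecker.builder()
//     .operandNames("column", "compression")
//     .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
//     .requiredOperandCount(1)
//     .build();
//
// checker.getAllowedSignatures(op, "TDIGEST_GENERATE_SKETCH")
//   -> 'TDIGEST_GENERATE_SKETCH(<column>, [<compression>])'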

View File

@ -19,13 +19,11 @@
package org.apache.druid.sql.calcite.expression;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
@ -37,13 +35,9 @@ import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
@ -51,7 +45,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Static;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
@ -69,7 +63,6 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
/**
* Utilities for assisting in writing {@link SqlOperatorConversion} implementations.
@ -334,6 +327,13 @@ public class OperatorConversions
return new AggregatorBuilder(name);
}
/**
* Helps create the operator builder, along with validation and type inference, for simple operator conversions.
*
* The type checker for this operator conversion can either be supplied manually, or an instance of
* {@link DefaultOperandTypeChecker} will be built if the user passes in {@link #operandTypes} and other optional
* parameters. Exactly one of the two approaches must be used with the builder.
*/
public static class OperatorBuilder<T extends SqlFunction>
{
protected final String name;
@ -491,8 +491,8 @@ public class OperatorConversions
}
/**
* Equivalent to calling {@link BasicOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a
* {@link BasicOperandTypeChecker} as our operand type checker.
* Equivalent to calling {@link DefaultOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a
* {@link DefaultOperandTypeChecker} as our operand type checker.
*
* May be used in conjunction with {@link #requiredOperandCount(int)} and {@link #literalOperands(int...)} in order
* to further refine operand checking logic.
@ -506,12 +506,9 @@ public class OperatorConversions
}
/**
* Equivalent to calling {@link BasicOperandTypeChecker.Builder#requiredOperandCount(int)}; leads to using a
* {@link BasicOperandTypeChecker} as our operand type checker.
*
* Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}.
* Equivalent to calling {@link DefaultOperandTypeChecker.Builder#requiredOperandCount(Integer)}; leads to using a
* {@link DefaultOperandTypeChecker} as our operand type checker.
*/
@Deprecated
public OperatorBuilder<T> requiredOperandCount(final int requiredOperandCount)
{
this.requiredOperandCount = requiredOperandCount;
@ -531,8 +528,8 @@ public class OperatorConversions
}
/**
* Equivalent to calling {@link BasicOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a
* {@link BasicOperandTypeChecker} as our operand type checker.
* Equivalent to calling {@link DefaultOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a
* {@link DefaultOperandTypeChecker} as our operand type checker.
*
* Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}.
*/
@ -554,37 +551,30 @@ public class OperatorConversions
@SuppressWarnings("unchecked")
public T build()
{
final IntSet nullableOperands = buildNullableOperands();
return (T) new SqlFunction(
name,
kind,
Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"),
buildOperandTypeInference(nullableOperands),
buildOperandTypeChecker(nullableOperands),
buildOperandTypeInference(),
buildOperandTypeChecker(),
functionCategory
);
}
protected IntSet buildNullableOperands()
{
// Create "nullableOperands" set including all optional arguments.
final IntSet nullableOperands = new IntArraySet();
if (requiredOperandCount != null) {
IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add);
}
return nullableOperands;
}
protected SqlOperandTypeChecker buildOperandTypeChecker(final IntSet nullableOperands)
protected SqlOperandTypeChecker buildOperandTypeChecker()
{
if (operandTypeChecker == null) {
return new DefaultOperandTypeChecker(
operandNames,
operandTypes,
requiredOperandCount == null ? operandTypes.size() : requiredOperandCount,
nullableOperands,
literalOperands
);
if (operandTypes == null) {
throw DruidException.defensive(
"'operandTypes' must be non null if 'operandTypeChecker' is not passed to the operator conversion.");
}
return DefaultOperandTypeChecker
.builder()
.operandNames(operandNames)
.operandTypes(operandTypes)
.requiredOperandCount(requiredOperandCount == null ? operandTypes.size() : requiredOperandCount)
.literalOperands(literalOperands)
.build();
} else if (operandNames.isEmpty()
&& operandTypes == null
&& requiredOperandCount == null
@ -598,8 +588,11 @@ public class OperatorConversions
}
}
protected SqlOperandTypeInference buildOperandTypeInference(final IntSet nullableOperands)
protected SqlOperandTypeInference buildOperandTypeInference()
{
final IntSet nullableOperands = requiredOperandCount == null
? new IntArraySet()
: DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size());
if (operandTypeInference == null) {
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
return (callBinding, returnType, types) -> {
@ -634,9 +627,8 @@ public class OperatorConversions
@Override
public SqlAggFunction build()
{
final IntSet nullableOperands = buildNullableOperands();
final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands);
final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(nullableOperands);
final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference();
final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker();
class DruidSqlAggFunction extends SqlAggFunction
{
@ -735,147 +727,6 @@ public class OperatorConversions
}
}
/**
* Operand type checker that is used in 'simple' situations: there are a particular number of operands, with
* particular types, some of which may be optional or nullable, and some of which may be required to be literals.
*/
@VisibleForTesting
static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
{
/**
* Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the
* {@link #operandTypes} are used instead.
*/
private final List<String> operandNames;
private final List<SqlTypeFamily> operandTypes;
private final int requiredOperands;
private final IntSet nullableOperands;
private final IntSet literalOperands;
public int getNumberOfLiteralOperands()
{
return literalOperands.size();
}
@VisibleForTesting
DefaultOperandTypeChecker(
final List<String> operandNames,
final List<SqlTypeFamily> operandTypes,
final int requiredOperands,
final IntSet nullableOperands,
@Nullable final int[] literalOperands
)
{
Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0);
this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames");
this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes");
this.requiredOperands = requiredOperands;
this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands");
if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) {
throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size());
}
if (literalOperands == null) {
this.literalOperands = IntSets.EMPTY_SET;
} else {
this.literalOperands = new IntArraySet();
Arrays.stream(literalOperands).forEach(this.literalOperands::add);
}
}
@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
{
for (int i = 0; i < callBinding.operands().size(); i++) {
final SqlNode operand = callBinding.operands().get(i);
if (literalOperands.contains(i)) {
// Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later.
if (!SqlUtil.isLiteral(operand, true)) {
return throwOrReturn(
throwOnFailure,
callBinding,
cb -> cb.getValidator()
.newValidationError(
operand,
Static.RESOURCE.argumentMustBeLiteral(callBinding.getOperator().getName())
)
);
}
}
final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand);
final SqlTypeFamily expectedFamily = operandTypes.get(i);
if (expectedFamily == SqlTypeFamily.ANY) {
// ANY matches anything. This operand is all good; do nothing.
} else if (expectedFamily.getTypeNames().contains(operandType.getSqlTypeName())) {
// Operand came in with one of the expected types.
} else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) {
// Null came in, check if operand is a nullable type.
if (!nullableOperands.contains(i)) {
return throwOrReturn(
throwOnFailure,
callBinding,
cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal())
);
}
} else {
return throwOrReturn(
throwOnFailure,
callBinding,
SqlCallBinding::newValidationSignatureError
);
}
}
return true;
}
@Override
public SqlOperandCountRange getOperandCountRange()
{
return SqlOperandCountRanges.between(requiredOperands, operandTypes.size());
}
@Override
public String getAllowedSignatures(SqlOperator op, String opName)
{
final List<?> operands = !operandNames.isEmpty() ? operandNames : operandTypes;
final StringBuilder ret = new StringBuilder();
ret.append("'");
ret.append(opName);
ret.append("(");
for (int i = 0; i < operands.size(); i++) {
if (i > 0) {
ret.append(", ");
}
if (i >= requiredOperands) {
ret.append("[");
}
ret.append("<").append(operands.get(i)).append(">");
}
for (int i = requiredOperands; i < operands.size(); i++) {
ret.append("]");
}
ret.append(")'");
return ret.toString();
}
@Override
public Consistency getConsistency()
{
return Consistency.NONE;
}
@Override
public boolean isOptional(int i)
{
return i + 1 > requiredOperands;
}
}
public static boolean throwOrReturn(
final boolean throwOnFailure,
final SqlCallBinding callBinding,

View File

@ -23,7 +23,6 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.java.util.common.StringUtils;
@ -52,13 +51,10 @@ public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConvers
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME))
.operandTypeChecker(
OperandTypes.sequence(
"'" + StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME) + "(typeName, base64)'",
OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL),
OperandTypes.ANY
)
)
.operandNames("typeName", "base64")
.operandTypes(SqlTypeFamily.STRING, SqlTypeFamily.ANY)
.requiredOperandCount(2)
.literalOperands(0)
.returnTypeInference(ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();

View File

@ -118,13 +118,10 @@ public class NestedDataOperatorConversions
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("JSON_KEYS")
.operandTypeChecker(
OperandTypes.sequence(
"'JSON_KEYS(expr, path)'",
OperandTypes.ANY,
OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL)
)
)
.operandNames("expr", "path")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.STRING)
.literalOperands(1)
.requiredOperandCount(2)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR)
.build();

View File

@ -31,7 +31,6 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider;
import org.apache.druid.query.lookup.RegisteredLookupExtractionFn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.expression.SqlOperatorConversion;
@ -43,16 +42,10 @@ public class QueryLookupOperatorConversion implements SqlOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("LOOKUP")
.operandTypeChecker(
BasicOperandTypeChecker.builder()
.operandTypes(
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER
)
.requiredOperandCount(2)
.literalOperands(2)
.build())
.operandNames("expr", "lookupName", "replaceMissingValueWith")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperandCount(2)
.literalOperands(2)
.returnTypeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.build();

View File

@ -27,7 +27,6 @@ import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql2rel.SqlRexContext;
@ -53,13 +52,10 @@ public class TimeInIntervalConvertletFactory implements DruidConvertletFactory
private static final SqlOperator OPERATOR = OperatorConversions
.operatorBuilder(NAME)
.operandTypeChecker(
OperandTypes.sequence(
"'" + NAME + "(<TIMESTAMP>, <LITERAL ISO8601 INTERVAL>)'",
OperandTypes.family(SqlTypeFamily.TIMESTAMP),
OperandTypes.and(OperandTypes.family(SqlTypeFamily.CHARACTER), OperandTypes.LITERAL)
)
)
.operandNames("timestamp", "interval")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperandCount(2)
.literalOperands(1)
.returnTypeNonNull(SqlTypeName.BOOLEAN)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -4454,6 +4454,48 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testGroupByCastedRootKeysJsonPath()
{
cannotVectorize();
testQuery(
"SELECT "
+ "JSON_KEYS(nester, CAST('$.' AS VARCHAR)), "
+ "SUM(cnt) "
+ "FROM druid.nested GROUP BY 1",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(DATA_SOURCE)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
new ExpressionVirtualColumn(
"v0",
"json_keys(\"nester\",'$.')",
ColumnType.STRING_ARRAY,
queryFramework().macroTable()
)
)
.setDimensions(
dimensions(
new DefaultDimensionSpec("v0", "d0", ColumnType.STRING_ARRAY)
)
)
.setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{null, 5L},
new Object[]{"[\"array\",\"n\"]", 2L}
),
RowSignature.builder()
.add("EXPR$0", ColumnType.STRING_ARRAY)
.add("EXPR$1", ColumnType.LONG)
.build()
);
}
@Test
public void testGroupByAllPaths()
{

View File

@ -642,10 +642,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
testQuery(
"SELECT "
+ "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), "
+ "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), EARLIEST(dim1, CAST(10 AS INTEGER)), "
+ "EARLIEST(cnt + 1), EARLIEST(m1 + 1), EARLIEST(dim1 || CAST(cnt AS VARCHAR), 10), "
+ "EARLIEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), "
+ "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) "
+ "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) "
+ "FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
@ -677,7 +677,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
ImmutableList.of(
new Object[]{1L, 1.0f, "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21"}
new Object[]{1L, 1.0f, "", "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21", "21"}
)
);
}
@ -752,10 +752,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
testQuery(
"SELECT "
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), LATEST(dim1, CAST(10 AS INTEGER)), "
+ "LATEST(cnt + 1), LATEST(m1 + 1), LATEST(dim1 || CAST(cnt AS VARCHAR), 10), "
+ "LATEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), "
+ "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) "
+ "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), CAST(10 AS INTEGER)) "
+ "FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
@ -787,7 +787,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
ImmutableList.of(
new Object[]{1L, 6.0f, "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11"}
new Object[]{1L, 6.0f, "abc", "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11", "10.11"}
)
);
}
@ -5612,6 +5612,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testCountStarWithTimeInCastedIntervalFilter()
{
testQuery(
"SELECT COUNT(*) FROM druid.foo "
+ "WHERE TIME_IN_INTERVAL(__time, CAST('2000-01-01/P1Y' AS VARCHAR)) "
+ "AND TIME_IN_INTERVAL(CURRENT_TIMESTAMP, '2000/3000') -- Optimized away: always true",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.of("2000-01-01/2001-01-01")))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{3L}
)
);
}
@Test
public void testCountStarWithTimeInIntervalFilterLosAngeles()
{
@ -5661,7 +5683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
expected -> {
expected.expect(CoreMatchers.instanceOf(DruidException.class));
expected.expect(ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [38])")));
"Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [63])")));
}
);
}
@ -14076,6 +14098,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
}
@Test
public void testComplexDecodeAgg()
{
@ -14109,6 +14132,43 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
);
}
@Test
public void testComplexDecodeAggWithCastedTypeName()
{
msqIncompatible();
cannotVectorize();
testQuery(
"SELECT "
+ "APPROX_COUNT_DISTINCT_BUILTIN(COMPLEX_DECODE_BASE64(CAST('hyperUnique' AS VARCHAR),PARSE_JSON(TO_JSON_STRING(unique_dim1)))) "
+ "FROM druid.foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(
expressionVirtualColumn(
"v0",
"complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))",
ColumnType.ofComplex("hyperUnique")
)
)
.aggregators(
new HyperUniquesAggregatorFactory(
"a0",
"v0",
false,
true
)
)
.build()
),
ImmutableList.of(
new Object[]{6L}
)
);
}
@NotYetSupported
@Test
public void testOrderByAlongWithInternalScanQuery()

View File

@ -20,7 +20,6 @@
package org.apache.druid.sql.calcite.expression;
import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.ints.IntSets;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.runtime.CalciteContextException;
import org.apache.calcite.runtime.Resources.ExInst;
@ -38,7 +37,6 @@ import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.junit.Assert;
import org.junit.Rule;
@ -51,7 +49,6 @@ import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@RunWith(Enclosed.class)
@ -65,13 +62,13 @@ public class OperatorConversionsTest
@Test
public void testGetOperandCountRange()
{
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
Collections.emptyList(),
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2,
IntSets.EMPTY_SET,
null
);
SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker
.builder()
.operandNames()
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.requiredOperandCount(2)
.literalOperands()
.build();
SqlOperandCountRange countRange = typeChecker.getOperandCountRange();
Assert.assertEquals(2, countRange.getMin());
Assert.assertEquals(3, countRange.getMax());
@ -80,13 +77,13 @@ public class OperatorConversionsTest
@Test
public void testIsOptional()
{
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
Collections.emptyList(),
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2,
IntSets.EMPTY_SET,
null
);
SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker
.builder()
.operandNames()
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.requiredOperandCount(2)
.literalOperands()
.build();
Assert.assertFalse(typeChecker.isOptional(0));
Assert.assertFalse(typeChecker.isOptional(1));
Assert.assertTrue(typeChecker.isOptional(2));
@ -121,7 +118,7 @@ public class OperatorConversionsTest
{
SqlFunction function = OperatorConversions
.operatorBuilder("testRequiredOperandsOnly")
.operandTypeChecker(BasicOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build())
.operandTypeChecker(DefaultOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build())
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();