mirror of https://github.com/apache/druid.git
Allow casted literal values in SQL functions accepting literals (#15282)
Functions that accept literals also allow casted literals. This shouldn't have an impact on the queries that the user writes. It enables the SQL functions to accept explicit cast, which is required with JDBC.
This commit is contained in:
parent
49e0cba7ba
commit
2ea7177f15
|
@ -44,6 +44,7 @@ java.util.LinkedList @ Use ArrayList or ArrayDeque instead
|
|||
java.util.Random#<init>() @ Use ThreadLocalRandom.current() or the constructor with a seed (the latter in tests only!)
|
||||
java.lang.Math#random() @ Use ThreadLocalRandom.current()
|
||||
java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly
|
||||
org.apache.calcite.sql.type.OperandTypes#LITERAL @ LITERAL type checker throws when literals with CAST are passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead.
|
||||
org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead
|
||||
org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory()
|
||||
org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead
|
||||
|
|
|
@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory;
|
||||
|
@ -37,6 +37,7 @@ import org.apache.druid.segment.column.ColumnType;
|
|||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.InputAccessor;
|
||||
|
@ -133,8 +134,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
|||
|
||||
private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
private static final String SIGNATURE_WITH_COMPRESSION = "'" + NAME + "(column, compression)'";
|
||||
|
||||
TDigestGenerateSketchSqlAggFunction()
|
||||
{
|
||||
super(
|
||||
|
@ -143,16 +142,19 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
ReturnTypes.explicit(SqlTypeName.OTHER),
|
||||
null,
|
||||
OperandTypes.or(
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE_WITH_COMPRESSION, OperandTypes.ANY, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
)
|
||||
),
|
||||
// Validation for signatures like 'TDIGEST_GENERATE_SKETCH(column)' and
|
||||
// 'TDIGEST_GENERATE_SKETCH(column, compression)'
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("column", "compression")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
.requiredOperandCount(1)
|
||||
.literalOperands(1)
|
||||
.build(),
|
||||
SqlFunctionCategory.USER_DEFINED_FUNCTION,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
Optionality.FORBIDDEN
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
|
||||
|
@ -40,6 +40,7 @@ import org.apache.druid.segment.column.ColumnType;
|
|||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.InputAccessor;
|
||||
|
@ -158,9 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
|||
|
||||
private static class TDigestSketchQuantileSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
private static final String SIGNATURE1 = "'" + NAME + "(column, quantile)'";
|
||||
private static final String SIGNATURE2 = "'" + NAME + "(column, quantile, compression)'";
|
||||
|
||||
TDigestSketchQuantileSqlAggFunction()
|
||||
{
|
||||
super(
|
||||
|
@ -169,19 +167,18 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
ReturnTypes.explicit(SqlTypeName.DOUBLE),
|
||||
null,
|
||||
OperandTypes.or(
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
),
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
|
||||
)
|
||||
),
|
||||
// Accounts for both 'TDIGEST_QUANTILE(column, quantile)' and 'TDIGEST_QUANTILE(column, quantile, compression)'
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("column", "quantile", "compression")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
|
||||
.literalOperands(1, 2)
|
||||
.requiredOperandCount(2)
|
||||
.build(),
|
||||
SqlFunctionCategory.USER_DEFINED_FUNCTION,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
Optionality.FORBIDDEN
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,6 +138,76 @@ public class TDigestSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCastedQuantileAndCompressionParamForTDigestQuantileAgg()
|
||||
{
|
||||
cannotVectorize();
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ "TDIGEST_QUANTILE(m1, CAST(0.0 AS DOUBLE)), "
|
||||
+ "TDIGEST_QUANTILE(m1, CAST(0.5 AS FLOAT), CAST(200 AS INTEGER)), "
|
||||
+ "TDIGEST_QUANTILE(m1, CAST(1.0 AS DOUBLE), 300)\n"
|
||||
+ "FROM foo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(ImmutableList.of(
|
||||
new TDigestSketchAggregatorFactory("a0:agg", "m1",
|
||||
TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION
|
||||
),
|
||||
new TDigestSketchAggregatorFactory("a1:agg", "m1",
|
||||
200
|
||||
),
|
||||
new TDigestSketchAggregatorFactory("a2:agg", "m1",
|
||||
300
|
||||
)
|
||||
))
|
||||
.postAggregators(
|
||||
new TDigestSketchToQuantilePostAggregator("a0", makeFieldAccessPostAgg("a0:agg"), 0.0f),
|
||||
new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.5f),
|
||||
new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 1.0f)
|
||||
)
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ResultMatchMode.EQUALS_EPS,
|
||||
ImmutableList.of(
|
||||
new Object[]{1.0, 3.5, 6.0}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputingSketchOnNumericValuesWithCastedCompressionParameter()
|
||||
{
|
||||
cannotVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ "TDIGEST_GENERATE_SKETCH(m1, CAST(200 AS INTEGER))"
|
||||
+ "FROM foo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(ImmutableList.of(
|
||||
new TDigestSketchAggregatorFactory("a0:agg", "m1", 200)
|
||||
))
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ResultMatchMode.EQUALS_EPS,
|
||||
ImmutableList.of(
|
||||
new String[]{
|
||||
"\"AAAAAT/wAAAAAAAAQBgAAAAAAABAaQAAAAAAAAAAAAY/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQAgAAAAAAAA/8AAAAAAAAEAQAAAAAAAAP/AAAAAAAABAFAAAAAAAAD/wAAAAAAAAQBgAAAAAAAA=\""
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputingSketchOnCastedString()
|
||||
{
|
||||
|
|
|
@ -40,7 +40,6 @@ import org.apache.calcite.util.Static;
|
|||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.PostAggregator;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.OperatorConversions;
|
||||
import org.apache.druid.sql.calcite.expression.PostAggregatorVisitor;
|
||||
|
@ -143,8 +142,8 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO
|
|||
final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand);
|
||||
|
||||
// Verify that 'operand' is a literal number.
|
||||
if (!SqlUtil.isLiteral(operand)) {
|
||||
return BasicOperandTypeChecker.throwOrReturn(
|
||||
if (!SqlUtil.isLiteral(operand, true)) {
|
||||
return OperatorConversions.throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
cb -> cb.getValidator()
|
||||
|
@ -156,7 +155,7 @@ public abstract class DoublesSketchListArgBaseOperatorConversion implements SqlO
|
|||
}
|
||||
|
||||
if (!SqlTypeFamily.NUMERIC.contains(operandType)) {
|
||||
return BasicOperandTypeChecker.throwOrReturn(
|
||||
return OperatorConversions.throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
SqlCallBinding::newValidationSignatureError
|
||||
|
|
|
@ -495,6 +495,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
+ " DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt + 123), 0.5) + 1000,\n"
|
||||
+ " ABS(DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt), 0.5)),\n"
|
||||
+ " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), 0.5, 0.8),\n"
|
||||
+ " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), CAST(0.5 AS DOUBLE), CAST(0.8 AS DOUBLE)),\n"
|
||||
+ " DS_HISTOGRAM(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n"
|
||||
+ " DS_RANK(DS_QUANTILES_SKETCH(cnt), 3),\n"
|
||||
+ " DS_CDF(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n"
|
||||
|
@ -588,41 +589,49 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
),
|
||||
new double[]{0.5d, 0.8d}
|
||||
),
|
||||
new DoublesSketchToHistogramPostAggregator(
|
||||
new DoublesSketchToQuantilesPostAggregator(
|
||||
"p13",
|
||||
new FieldAccessPostAggregator(
|
||||
"p12",
|
||||
"a2:agg"
|
||||
),
|
||||
new double[]{0.5d, 0.8d}
|
||||
),
|
||||
new DoublesSketchToHistogramPostAggregator(
|
||||
"p15",
|
||||
new FieldAccessPostAggregator(
|
||||
"p14",
|
||||
"a2:agg"
|
||||
),
|
||||
new double[]{0.2d, 0.6d},
|
||||
null
|
||||
),
|
||||
new DoublesSketchToRankPostAggregator(
|
||||
"p15",
|
||||
new FieldAccessPostAggregator(
|
||||
"p14",
|
||||
"a2:agg"
|
||||
),
|
||||
3.0d
|
||||
),
|
||||
new DoublesSketchToCDFPostAggregator(
|
||||
"p17",
|
||||
new FieldAccessPostAggregator(
|
||||
"p16",
|
||||
"a2:agg"
|
||||
),
|
||||
new double[]{0.2d, 0.6d}
|
||||
3.0d
|
||||
),
|
||||
new DoublesSketchToStringPostAggregator(
|
||||
new DoublesSketchToCDFPostAggregator(
|
||||
"p19",
|
||||
new FieldAccessPostAggregator(
|
||||
"p18",
|
||||
"a2:agg"
|
||||
),
|
||||
new double[]{0.2d, 0.6d}
|
||||
),
|
||||
new DoublesSketchToStringPostAggregator(
|
||||
"p21",
|
||||
new FieldAccessPostAggregator(
|
||||
"p20",
|
||||
"a2:agg"
|
||||
)
|
||||
),
|
||||
new ExpressionPostAggregator(
|
||||
"p20",
|
||||
"replace(replace(\"p19\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch'),"
|
||||
"p22",
|
||||
"replace(replace(\"p21\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch'),"
|
||||
+ "'Combined Buffer Capacity : 6',"
|
||||
+ "'Combined Buffer Capacity : 8')",
|
||||
null,
|
||||
|
@ -640,6 +649,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
1124.0d,
|
||||
1.0d,
|
||||
"[1.0,1.0]",
|
||||
"[1.0,1.0]",
|
||||
"[0.0,0.0,6.0]",
|
||||
1.0d,
|
||||
"[0.0,0.0,1.0]",
|
||||
|
|
|
@ -25,10 +25,10 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory;
|
||||
|
@ -38,6 +38,7 @@ import org.apache.druid.query.dimension.ExtractionDimensionSpec;
|
|||
import org.apache.druid.segment.column.ColumnType;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.Expressions;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
|
@ -168,8 +169,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
|
|||
|
||||
private static class BloomFilterSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
private static final String SIGNATURE1 = "'" + NAME + "(column, maxNumEntries)'";
|
||||
|
||||
BloomFilterSqlAggFunction()
|
||||
{
|
||||
super(
|
||||
|
@ -178,13 +177,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
ReturnTypes.explicit(SqlTypeName.OTHER),
|
||||
null,
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
),
|
||||
// Allow signatures like 'BLOOM_FILTER(column, maxNumEntries)'
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("column", "maxNumEntries")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
.literalOperands(1)
|
||||
.requiredOperandCount(2)
|
||||
.build(),
|
||||
SqlFunctionCategory.USER_DEFINED_FUNCTION,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
Optionality.FORBIDDEN
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,7 +125,8 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ "BLOOM_FILTER(dim1, 1000)\n"
|
||||
+ "BLOOM_FILTER(dim1, 1000),\n"
|
||||
+ "BLOOM_FILTER(dim1, CAST(1000 AS INTEGER))\n"
|
||||
+ "FROM numfoo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -145,7 +146,10 @@ public class BloomFilterSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{queryFramework().queryJsonMapper().writeValueAsString(expected1)}
|
||||
new Object[]{
|
||||
queryFramework().queryJsonMapper().writeValueAsString(expected1),
|
||||
queryFramework().queryJsonMapper().writeValueAsString(expected1)
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram;
|
||||
|
@ -39,6 +39,7 @@ import org.apache.druid.segment.column.ColumnType;
|
|||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.InputAccessor;
|
||||
|
@ -221,15 +222,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
|||
|
||||
private static class FixedBucketsHistogramQuantileSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
private static final String SIGNATURE1 =
|
||||
"'"
|
||||
+ NAME
|
||||
+ "(column, probability, numBuckets, lowerLimit, upperLimit)'";
|
||||
private static final String SIGNATURE2 =
|
||||
"'"
|
||||
+ NAME
|
||||
+ "(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)'";
|
||||
|
||||
FixedBucketsHistogramQuantileSqlAggFunction()
|
||||
{
|
||||
super(
|
||||
|
@ -238,47 +230,33 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
ReturnTypes.explicit(SqlTypeName.DOUBLE),
|
||||
null,
|
||||
OperandTypes.or(
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(
|
||||
SIGNATURE1,
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL
|
||||
),
|
||||
OperandTypes.family(
|
||||
SqlTypeFamily.ANY,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC
|
||||
)
|
||||
),
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(
|
||||
SIGNATURE2,
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL,
|
||||
OperandTypes.LITERAL
|
||||
),
|
||||
OperandTypes.family(
|
||||
SqlTypeFamily.ANY,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.STRING
|
||||
)
|
||||
// Allows signatures like 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit)'
|
||||
// and 'APPROX_QUANTILE_FIXED_BUCKETS(column, probability, numBuckets, lowerLimit, upperLimit, outlierHandlingMode)'
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames(
|
||||
"column",
|
||||
"probability",
|
||||
"numBuckets",
|
||||
"lowerLimit",
|
||||
"upperLimit",
|
||||
"outlierHandlingMode"
|
||||
)
|
||||
),
|
||||
.operandTypes(
|
||||
SqlTypeFamily.ANY,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.NUMERIC,
|
||||
SqlTypeFamily.STRING
|
||||
)
|
||||
.literalOperands(1, 2, 3, 4, 5)
|
||||
.requiredOperandCount(5)
|
||||
.build(),
|
||||
SqlFunctionCategory.NUMERIC,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
Optionality.FORBIDDEN
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,10 +26,10 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.histogram.ApproximateHistogram;
|
||||
|
@ -41,6 +41,7 @@ import org.apache.druid.segment.column.ValueType;
|
|||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.InputAccessor;
|
||||
|
@ -196,9 +197,6 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
|
||||
private static class QuantileSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'";
|
||||
private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'";
|
||||
|
||||
QuantileSqlAggFunction()
|
||||
{
|
||||
super(
|
||||
|
@ -207,19 +205,19 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
ReturnTypes.explicit(SqlTypeName.DOUBLE),
|
||||
null,
|
||||
OperandTypes.or(
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
),
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
|
||||
)
|
||||
),
|
||||
// Checks for signatures like 'APPROX_QUANTILE(column, probability)' and
|
||||
// 'APPROX_QUANTILE(column, probability, resolution)'
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("column", "probability", "resolution")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands(1, 2)
|
||||
.build(),
|
||||
SqlFunctionCategory.NUMERIC,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
Optionality.FORBIDDEN
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -128,6 +128,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
|
|||
6.494999885559082,
|
||||
5.497499942779541,
|
||||
6.499499797821045,
|
||||
6.499499797821045,
|
||||
6.499499797821045,
|
||||
6.499499797821045,
|
||||
1.25
|
||||
}
|
||||
);
|
||||
|
@ -142,6 +145,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
|
|||
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.99, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 <> 'abc'),\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'ignore') FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'clip') FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.999, 20, 0.0, 10.0, 'overflow') FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS(cnt, 0.5, 20, 0.0, 10.0)\n"
|
||||
+ "FROM foo",
|
||||
ImmutableList.of(
|
||||
|
@ -200,8 +206,32 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
|
|||
),
|
||||
not(equality("dim1", "abc", ColumnType.STRING))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new FixedBucketsHistogramAggregatorFactory(
|
||||
"a9:agg",
|
||||
"m1",
|
||||
20,
|
||||
0.0d,
|
||||
10.0d,
|
||||
FixedBucketsHistogram.OutlierHandlingMode.CLIP,
|
||||
false
|
||||
),
|
||||
equality("dim1", "abc", ColumnType.STRING)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new FixedBucketsHistogramAggregatorFactory(
|
||||
"a10:agg",
|
||||
"m1",
|
||||
20,
|
||||
0.0d,
|
||||
10.0d,
|
||||
FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW,
|
||||
false
|
||||
),
|
||||
equality("dim1", "abc", ColumnType.STRING)
|
||||
),
|
||||
new FixedBucketsHistogramAggregatorFactory(
|
||||
"a8:agg",
|
||||
"a11:agg",
|
||||
"cnt",
|
||||
20,
|
||||
0.0d,
|
||||
|
@ -219,7 +249,55 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQ
|
|||
new QuantilePostAggregator("a5", "a5:agg", 0.99f),
|
||||
new QuantilePostAggregator("a6", "a6:agg", 0.999f),
|
||||
new QuantilePostAggregator("a7", "a5:agg", 0.999f),
|
||||
new QuantilePostAggregator("a8", "a8:agg", 0.50f)
|
||||
new QuantilePostAggregator("a8", "a5:agg", 0.999f),
|
||||
new QuantilePostAggregator("a9", "a9:agg", 0.999f),
|
||||
new QuantilePostAggregator("a10", "a10:agg", 0.999f),
|
||||
new QuantilePostAggregator("a11", "a11:agg", 0.50f)
|
||||
)
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
expectedResults
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuantileWithCastedLiteralArguments()
|
||||
{
|
||||
final List<Object[]> expectedResults = ImmutableList.of(new Object[]{6.499499797821045});
|
||||
testQuery(
|
||||
"SELECT\n"
|
||||
+ "APPROX_QUANTILE_FIXED_BUCKETS("
|
||||
+ "m1, "
|
||||
+ "CAST(0.999 AS DOUBLE), "
|
||||
+ "CAST(20 AS INTEGER), "
|
||||
+ "CAST(0.0 AS DOUBLE), "
|
||||
+ "CAST(10.0 AS DOUBLE), "
|
||||
+ "CAST('overflow' AS VARCHAR)"
|
||||
+ ") "
|
||||
+ "FILTER(WHERE dim1 = 'abc')\n"
|
||||
+ "FROM foo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(ImmutableList.of(
|
||||
new FilteredAggregatorFactory(
|
||||
new FixedBucketsHistogramAggregatorFactory(
|
||||
"a0:agg",
|
||||
"m1",
|
||||
20,
|
||||
0.0d,
|
||||
10.0d,
|
||||
FixedBucketsHistogram.OutlierHandlingMode.OVERFLOW,
|
||||
false
|
||||
),
|
||||
equality("dim1", "abc", ColumnType.STRING)
|
||||
)
|
||||
))
|
||||
.postAggregators(
|
||||
new QuantilePostAggregator("a0", "a0:agg", 0.999f)
|
||||
)
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
|
|
|
@ -121,6 +121,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
"SELECT\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.01),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.5, 50),\n"
|
||||
+ "APPROX_QUANTILE(m1, CAST(0.5 AS DOUBLE), CAST(50 AS INTEGER)),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.98, 200),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.99),\n"
|
||||
+ "APPROX_QUANTILE(m1 * 2, 0.97),\n"
|
||||
|
@ -144,28 +145,29 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.aggregators(ImmutableList.of(
|
||||
new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null, false),
|
||||
new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null, false),
|
||||
new ApproximateHistogramAggregatorFactory("a4:agg", "v0", null, null, null, null, false),
|
||||
new ApproximateHistogramAggregatorFactory("a3:agg", "m1", 200, null, null, null, false),
|
||||
new ApproximateHistogramAggregatorFactory("a5:agg", "v0", null, null, null, null, false),
|
||||
new FilteredAggregatorFactory(
|
||||
new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null, false),
|
||||
new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false),
|
||||
equality("dim1", "abc", ColumnType.STRING)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new ApproximateHistogramAggregatorFactory("a6:agg", "m1", null, null, null, null, false),
|
||||
new ApproximateHistogramAggregatorFactory("a7:agg", "m1", null, null, null, null, false),
|
||||
not(equality("dim1", "abc", ColumnType.STRING))
|
||||
),
|
||||
new ApproximateHistogramAggregatorFactory("a8:agg", "cnt", null, null, null, null, false)
|
||||
new ApproximateHistogramAggregatorFactory("a9:agg", "cnt", null, null, null, null, false)
|
||||
))
|
||||
.postAggregators(
|
||||
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
|
||||
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
|
||||
new QuantilePostAggregator("a2", "a2:agg", 0.98f),
|
||||
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
|
||||
new QuantilePostAggregator("a4", "a4:agg", 0.97f),
|
||||
new QuantilePostAggregator("a5", "a5:agg", 0.99f),
|
||||
new QuantilePostAggregator("a6", "a6:agg", 0.999f),
|
||||
new QuantilePostAggregator("a7", "a5:agg", 0.999f),
|
||||
new QuantilePostAggregator("a8", "a8:agg", 0.50f)
|
||||
new QuantilePostAggregator("a2", "a0:agg", 0.50f),
|
||||
new QuantilePostAggregator("a3", "a3:agg", 0.98f),
|
||||
new QuantilePostAggregator("a4", "a0:agg", 0.99f),
|
||||
new QuantilePostAggregator("a5", "a5:agg", 0.97f),
|
||||
new QuantilePostAggregator("a6", "a6:agg", 0.99f),
|
||||
new QuantilePostAggregator("a7", "a7:agg", 0.999f),
|
||||
new QuantilePostAggregator("a8", "a6:agg", 0.999f),
|
||||
new QuantilePostAggregator("a9", "a9:agg", 0.50f)
|
||||
)
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
|
@ -174,6 +176,7 @@ public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
new Object[]{
|
||||
1.0,
|
||||
3.0,
|
||||
3.0,
|
||||
5.880000114440918,
|
||||
5.940000057220459,
|
||||
11.640000343322754,
|
||||
|
|
|
@ -33,8 +33,8 @@ import org.apache.calcite.sql.SqlNode;
|
|||
import org.apache.calcite.sql.SqlOperatorBinding;
|
||||
import org.apache.calcite.sql.parser.SqlParserPos;
|
||||
import org.apache.calcite.sql.type.InferTypes;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.SqlReturnTypeInference;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.sql.type.SqlTypeUtil;
|
||||
import org.apache.calcite.sql.util.SqlVisitor;
|
||||
|
@ -61,6 +61,7 @@ import org.apache.druid.segment.column.ColumnHolder;
|
|||
import org.apache.druid.segment.column.ColumnType;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.Expressions;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
|
@ -369,14 +370,13 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE,
|
||||
InferTypes.RETURN_TYPE,
|
||||
OperandTypes.or(
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.sequence(
|
||||
"'" + aggregatorType.name() + "(expr, maxBytesPerString)'",
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL)
|
||||
)
|
||||
),
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("expr", "maxBytesPerString")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
.requiredOperandCount(1)
|
||||
.literalOperands(1)
|
||||
.build(),
|
||||
SqlFunctionCategory.USER_DEFINED_FUNCTION,
|
||||
false,
|
||||
false,
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.apache.calcite.sql.SqlAggFunction;
|
|||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.InferTypes;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.SqlReturnTypeInference;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
|
@ -38,6 +37,7 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregat
|
|||
import org.apache.druid.segment.column.ColumnType;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.Expressions;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
|
@ -176,19 +176,13 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
EARLIEST_LATEST_ARG0_RETURN_TYPE_INFERENCE,
|
||||
InferTypes.RETURN_TYPE,
|
||||
OperandTypes.or(
|
||||
OperandTypes.sequence(
|
||||
"'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn)'",
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.family(SqlTypeFamily.TIMESTAMP)
|
||||
),
|
||||
OperandTypes.sequence(
|
||||
"'" + StringUtils.format("%s_BY", aggregatorType.name()) + "(expr, timeColumn, maxBytesPerString)'",
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.family(SqlTypeFamily.TIMESTAMP),
|
||||
OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL)
|
||||
)
|
||||
),
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("expr", "timeColumn", "maxBytesPerString")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.NUMERIC)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands(2)
|
||||
.build(),
|
||||
SqlFunctionCategory.USER_DEFINED_FUNCTION,
|
||||
false,
|
||||
false,
|
||||
|
|
|
@ -24,7 +24,6 @@ import it.unimi.dsi.fastutil.ints.IntArraySet;
|
|||
import it.unimi.dsi.fastutil.ints.IntSet;
|
||||
import it.unimi.dsi.fastutil.ints.IntSets;
|
||||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.runtime.CalciteException;
|
||||
import org.apache.calcite.sql.SqlCallBinding;
|
||||
import org.apache.calcite.sql.SqlNode;
|
||||
import org.apache.calcite.sql.SqlOperandCountRange;
|
||||
|
@ -35,35 +34,47 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker;
|
|||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.util.Static;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Operand type checker that is used in simple situations: there are a particular number of operands, with
|
||||
* Operand type checker that is used in 'simple' situations: there are a particular number of operands, with
|
||||
* particular types, some of which may be optional or nullable, and some of which may be required to be literals.
|
||||
*/
|
||||
public class BasicOperandTypeChecker implements SqlOperandTypeChecker
|
||||
public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
|
||||
{
|
||||
/**
|
||||
* Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the
|
||||
* {@link #operandTypes} are used instead.
|
||||
*/
|
||||
private final List<String> operandNames;
|
||||
private final List<SqlTypeFamily> operandTypes;
|
||||
private final int requiredOperands;
|
||||
private final IntSet nullOperands;
|
||||
private final IntSet nullableOperands;
|
||||
private final IntSet literalOperands;
|
||||
|
||||
BasicOperandTypeChecker(
|
||||
private DefaultOperandTypeChecker(
|
||||
final List<String> operandNames,
|
||||
final List<SqlTypeFamily> operandTypes,
|
||||
final int requiredOperands,
|
||||
final IntSet nullOperands,
|
||||
final IntSet nullableOperands,
|
||||
@Nullable final int[] literalOperands
|
||||
)
|
||||
{
|
||||
Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0);
|
||||
this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames");
|
||||
this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes");
|
||||
this.requiredOperands = requiredOperands;
|
||||
this.nullOperands = Preconditions.checkNotNull(nullOperands, "nullOperands");
|
||||
this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands");
|
||||
|
||||
if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) {
|
||||
throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size());
|
||||
}
|
||||
|
||||
if (literalOperands == null) {
|
||||
this.literalOperands = IntSets.EMPTY_SET;
|
||||
|
@ -78,19 +89,6 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
|
|||
return new Builder();
|
||||
}
|
||||
|
||||
public static boolean throwOrReturn(
|
||||
final boolean throwOnFailure,
|
||||
final SqlCallBinding callBinding,
|
||||
final Function<SqlCallBinding, CalciteException> exceptionMapper
|
||||
)
|
||||
{
|
||||
if (throwOnFailure) {
|
||||
throw exceptionMapper.apply(callBinding);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
|
||||
{
|
||||
|
@ -98,9 +96,9 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
|
|||
final SqlNode operand = callBinding.operands().get(i);
|
||||
|
||||
if (literalOperands.contains(i)) {
|
||||
// Verify that 'operand' is a literal.
|
||||
if (!SqlUtil.isLiteral(operand)) {
|
||||
return throwOrReturn(
|
||||
// Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later.
|
||||
if (!SqlUtil.isLiteral(operand, true)) {
|
||||
return OperatorConversions.throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
cb -> cb.getValidator()
|
||||
|
@ -121,15 +119,15 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
|
|||
// Operand came in with one of the expected types.
|
||||
} else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) {
|
||||
// Null came in, check if operand is a nullable type.
|
||||
if (!nullOperands.contains(i)) {
|
||||
return throwOrReturn(
|
||||
if (!nullableOperands.contains(i)) {
|
||||
return OperatorConversions.throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal())
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return throwOrReturn(
|
||||
return OperatorConversions.throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
SqlCallBinding::newValidationSignatureError
|
||||
|
@ -149,7 +147,25 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
|
|||
@Override
|
||||
public String getAllowedSignatures(SqlOperator op, String opName)
|
||||
{
|
||||
return SqlUtil.getAliasedSignature(op, opName, operandTypes);
|
||||
final List<?> operands = !operandNames.isEmpty() ? operandNames : operandTypes;
|
||||
final StringBuilder ret = new StringBuilder();
|
||||
ret.append("'");
|
||||
ret.append(opName);
|
||||
ret.append("(");
|
||||
for (int i = 0; i < operands.size(); i++) {
|
||||
if (i > 0) {
|
||||
ret.append(", ");
|
||||
}
|
||||
if (i >= requiredOperands) {
|
||||
ret.append("[");
|
||||
}
|
||||
ret.append("<").append(operands.get(i)).append(">");
|
||||
}
|
||||
for (int i = requiredOperands; i < operands.size(); i++) {
|
||||
ret.append("]");
|
||||
}
|
||||
ret.append(")'");
|
||||
return ret.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -166,64 +182,70 @@ public class BasicOperandTypeChecker implements SqlOperandTypeChecker
|
|||
|
||||
public static class Builder
|
||||
{
|
||||
private List<String> operandNames = Collections.emptyList();
|
||||
private List<SqlTypeFamily> operandTypes;
|
||||
private Integer requiredOperandCount = null;
|
||||
private int[] literalOperands = null;
|
||||
|
||||
/**
|
||||
* Signifies that a function accepts operands of type family given by {@param operandTypes}.
|
||||
*/
|
||||
@Nullable
|
||||
private Integer requiredOperandCount;
|
||||
private int[] literalOperands;
|
||||
|
||||
private Builder()
|
||||
{
|
||||
}
|
||||
|
||||
public Builder operandNames(final String... operandNames)
|
||||
{
|
||||
this.operandNames = Arrays.asList(operandNames);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder operandNames(final List<String> operandNames)
|
||||
{
|
||||
this.operandNames = operandNames;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder operandTypes(final SqlTypeFamily... operandTypes)
|
||||
{
|
||||
this.operandTypes = Arrays.asList(operandTypes);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Signifies that a function accepts operands of type family given by {@param operandTypes}.
|
||||
*/
|
||||
public Builder operandTypes(final List<SqlTypeFamily> operandTypes)
|
||||
{
|
||||
this.operandTypes = operandTypes;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Signifies that the first {@code requiredOperands} operands are required, and all later operands are optional.
|
||||
*
|
||||
* Required operands are not allowed to be null. Optional operands can either be skipped or explicitly provided as
|
||||
* literal NULLs. For example, if {@code requiredOperands == 1}, then {@code F(x, NULL)} and {@code F(x)} are both
|
||||
* accepted, and {@code x} must not be null.
|
||||
*/
|
||||
public Builder requiredOperandCount(final int requiredOperandCount)
|
||||
public Builder requiredOperandCount(Integer requiredOperandCount)
|
||||
{
|
||||
this.requiredOperandCount = requiredOperandCount;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Signifies that the operands at positions given by {@code literalOperands} must be literals.
|
||||
*/
|
||||
public Builder literalOperands(final int... literalOperands)
|
||||
{
|
||||
this.literalOperands = literalOperands;
|
||||
return this;
|
||||
}
|
||||
|
||||
public BasicOperandTypeChecker build()
|
||||
public DefaultOperandTypeChecker build()
|
||||
{
|
||||
// Create "nullableOperands" set including all optional arguments.
|
||||
final IntSet nullableOperands = new IntArraySet();
|
||||
if (requiredOperandCount != null) {
|
||||
IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add);
|
||||
}
|
||||
|
||||
return new BasicOperandTypeChecker(
|
||||
int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount;
|
||||
return new DefaultOperandTypeChecker(
|
||||
operandNames,
|
||||
operandTypes,
|
||||
requiredOperandCount == null ? operandTypes.size() : requiredOperandCount,
|
||||
nullableOperands,
|
||||
computedRequiredOperandCount,
|
||||
DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()),
|
||||
literalOperands
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount)
|
||||
{
|
||||
final IntSet nullableOperands = new IntArraySet();
|
||||
IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add);
|
||||
return nullableOperands;
|
||||
}
|
||||
}
|
|
@ -19,13 +19,11 @@
|
|||
|
||||
package org.apache.druid.sql.calcite.expression;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Iterables;
|
||||
import it.unimi.dsi.fastutil.ints.IntArraySet;
|
||||
import it.unimi.dsi.fastutil.ints.IntSet;
|
||||
import it.unimi.dsi.fastutil.ints.IntSets;
|
||||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.rex.RexCall;
|
||||
import org.apache.calcite.rex.RexInputRef;
|
||||
|
@ -37,13 +35,9 @@ import org.apache.calcite.sql.SqlCallBinding;
|
|||
import org.apache.calcite.sql.SqlFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.SqlNode;
|
||||
import org.apache.calcite.sql.SqlOperandCountRange;
|
||||
import org.apache.calcite.sql.SqlOperator;
|
||||
import org.apache.calcite.sql.SqlUtil;
|
||||
import org.apache.calcite.sql.type.BasicSqlType;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlOperandCountRanges;
|
||||
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
|
||||
import org.apache.calcite.sql.type.SqlOperandTypeInference;
|
||||
import org.apache.calcite.sql.type.SqlReturnTypeInference;
|
||||
|
@ -51,7 +45,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
|
|||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.sql.type.SqlTypeTransforms;
|
||||
import org.apache.calcite.util.Optionality;
|
||||
import org.apache.calcite.util.Static;
|
||||
import org.apache.druid.error.DruidException;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
|
@ -69,7 +63,6 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Utilities for assisting in writing {@link SqlOperatorConversion} implementations.
|
||||
|
@ -334,6 +327,13 @@ public class OperatorConversions
|
|||
return new AggregatorBuilder(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helps in creating the operator builder along with validations and type inference for simple operator conversions.
|
||||
*
|
||||
* The type checker for this operator conversion can be either supplied manually, or an instance of
|
||||
* {@link DefaultOperandTypeChecker} will be used if the user passes in {@link #operandTypes} and other optional
|
||||
* parameters. Exactly one of them must be supplied to the builder.
|
||||
*/
|
||||
public static class OperatorBuilder<T extends SqlFunction>
|
||||
{
|
||||
protected final String name;
|
||||
|
@ -491,8 +491,8 @@ public class OperatorConversions
|
|||
}
|
||||
|
||||
/**
|
||||
* Equivalent to calling {@link BasicOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a
|
||||
* {@link BasicOperandTypeChecker} as our operand type checker.
|
||||
* Equivalent to calling {@link DefaultOperandTypeChecker.Builder#operandTypes(SqlTypeFamily...)}; leads to using a
|
||||
* {@link DefaultOperandTypeChecker} as our operand type checker.
|
||||
*
|
||||
* May be used in conjunction with {@link #requiredOperandCount(int)} and {@link #literalOperands(int...)} in order
|
||||
* to further refine operand checking logic.
|
||||
|
@ -506,12 +506,9 @@ public class OperatorConversions
|
|||
}
|
||||
|
||||
/**
|
||||
* Equivalent to calling {@link BasicOperandTypeChecker.Builder#requiredOperandCount(int)}; leads to using a
|
||||
* {@link BasicOperandTypeChecker} as our operand type checker.
|
||||
*
|
||||
* Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}.
|
||||
* Equivalent to calling {@link DefaultOperandTypeChecker.Builder#requiredOperandCount(Integer)}; leads to using a
|
||||
* {@link DefaultOperandTypeChecker} as our operand type checker.
|
||||
*/
|
||||
@Deprecated
|
||||
public OperatorBuilder<T> requiredOperandCount(final int requiredOperandCount)
|
||||
{
|
||||
this.requiredOperandCount = requiredOperandCount;
|
||||
|
@ -531,8 +528,8 @@ public class OperatorConversions
|
|||
}
|
||||
|
||||
/**
|
||||
* Equivalent to calling {@link BasicOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a
|
||||
* {@link BasicOperandTypeChecker} as our operand type checker.
|
||||
* Equivalent to calling {@link DefaultOperandTypeChecker.Builder#literalOperands(int...)}; leads to using a
|
||||
* {@link DefaultOperandTypeChecker} as our operand type checker.
|
||||
*
|
||||
* Not compatible with {@link #operandTypeChecker(SqlOperandTypeChecker)}.
|
||||
*/
|
||||
|
@ -554,37 +551,30 @@ public class OperatorConversions
|
|||
@SuppressWarnings("unchecked")
|
||||
public T build()
|
||||
{
|
||||
final IntSet nullableOperands = buildNullableOperands();
|
||||
return (T) new SqlFunction(
|
||||
name,
|
||||
kind,
|
||||
Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"),
|
||||
buildOperandTypeInference(nullableOperands),
|
||||
buildOperandTypeChecker(nullableOperands),
|
||||
buildOperandTypeInference(),
|
||||
buildOperandTypeChecker(),
|
||||
functionCategory
|
||||
);
|
||||
}
|
||||
|
||||
protected IntSet buildNullableOperands()
|
||||
{
|
||||
// Create "nullableOperands" set including all optional arguments.
|
||||
final IntSet nullableOperands = new IntArraySet();
|
||||
if (requiredOperandCount != null) {
|
||||
IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add);
|
||||
}
|
||||
return nullableOperands;
|
||||
}
|
||||
|
||||
protected SqlOperandTypeChecker buildOperandTypeChecker(final IntSet nullableOperands)
|
||||
protected SqlOperandTypeChecker buildOperandTypeChecker()
|
||||
{
|
||||
if (operandTypeChecker == null) {
|
||||
return new DefaultOperandTypeChecker(
|
||||
operandNames,
|
||||
operandTypes,
|
||||
requiredOperandCount == null ? operandTypes.size() : requiredOperandCount,
|
||||
nullableOperands,
|
||||
literalOperands
|
||||
);
|
||||
if (operandTypes == null) {
|
||||
throw DruidException.defensive(
|
||||
"'operandTypes' must be non null if 'operandTypeChecker' is not passed to the operator conversion.");
|
||||
}
|
||||
return DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames(operandNames)
|
||||
.operandTypes(operandTypes)
|
||||
.requiredOperandCount(requiredOperandCount == null ? operandTypes.size() : requiredOperandCount)
|
||||
.literalOperands(literalOperands)
|
||||
.build();
|
||||
} else if (operandNames.isEmpty()
|
||||
&& operandTypes == null
|
||||
&& requiredOperandCount == null
|
||||
|
@ -598,8 +588,11 @@ public class OperatorConversions
|
|||
}
|
||||
}
|
||||
|
||||
protected SqlOperandTypeInference buildOperandTypeInference(final IntSet nullableOperands)
|
||||
protected SqlOperandTypeInference buildOperandTypeInference()
|
||||
{
|
||||
final IntSet nullableOperands = requiredOperandCount == null
|
||||
? new IntArraySet()
|
||||
: DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size());
|
||||
if (operandTypeInference == null) {
|
||||
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
|
||||
return (callBinding, returnType, types) -> {
|
||||
|
@ -634,9 +627,8 @@ public class OperatorConversions
|
|||
@Override
|
||||
public SqlAggFunction build()
|
||||
{
|
||||
final IntSet nullableOperands = buildNullableOperands();
|
||||
final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands);
|
||||
final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(nullableOperands);
|
||||
final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference();
|
||||
final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker();
|
||||
|
||||
class DruidSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
|
@ -735,147 +727,6 @@ public class OperatorConversions
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Operand type checker that is used in 'simple' situations: there are a particular number of operands, with
|
||||
* particular types, some of which may be optional or nullable, and some of which may be required to be literals.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
|
||||
{
|
||||
/**
|
||||
* Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the
|
||||
* {@link #operandTypes} are used instead.
|
||||
*/
|
||||
private final List<String> operandNames;
|
||||
private final List<SqlTypeFamily> operandTypes;
|
||||
private final int requiredOperands;
|
||||
private final IntSet nullableOperands;
|
||||
private final IntSet literalOperands;
|
||||
|
||||
public int getNumberOfLiteralOperands()
|
||||
{
|
||||
return literalOperands.size();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
DefaultOperandTypeChecker(
|
||||
final List<String> operandNames,
|
||||
final List<SqlTypeFamily> operandTypes,
|
||||
final int requiredOperands,
|
||||
final IntSet nullableOperands,
|
||||
@Nullable final int[] literalOperands
|
||||
)
|
||||
{
|
||||
Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0);
|
||||
this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames");
|
||||
this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes");
|
||||
this.requiredOperands = requiredOperands;
|
||||
this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands");
|
||||
|
||||
if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) {
|
||||
throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size());
|
||||
}
|
||||
|
||||
if (literalOperands == null) {
|
||||
this.literalOperands = IntSets.EMPTY_SET;
|
||||
} else {
|
||||
this.literalOperands = new IntArraySet();
|
||||
Arrays.stream(literalOperands).forEach(this.literalOperands::add);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure)
|
||||
{
|
||||
for (int i = 0; i < callBinding.operands().size(); i++) {
|
||||
final SqlNode operand = callBinding.operands().get(i);
|
||||
|
||||
if (literalOperands.contains(i)) {
|
||||
// Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later.
|
||||
if (!SqlUtil.isLiteral(operand, true)) {
|
||||
return throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
cb -> cb.getValidator()
|
||||
.newValidationError(
|
||||
operand,
|
||||
Static.RESOURCE.argumentMustBeLiteral(callBinding.getOperator().getName())
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand);
|
||||
final SqlTypeFamily expectedFamily = operandTypes.get(i);
|
||||
|
||||
if (expectedFamily == SqlTypeFamily.ANY) {
|
||||
// ANY matches anything. This operand is all good; do nothing.
|
||||
} else if (expectedFamily.getTypeNames().contains(operandType.getSqlTypeName())) {
|
||||
// Operand came in with one of the expected types.
|
||||
} else if (operandType.getSqlTypeName() == SqlTypeName.NULL || SqlUtil.isNullLiteral(operand, true)) {
|
||||
// Null came in, check if operand is a nullable type.
|
||||
if (!nullableOperands.contains(i)) {
|
||||
return throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
cb -> cb.getValidator().newValidationError(operand, Static.RESOURCE.nullIllegal())
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return throwOrReturn(
|
||||
throwOnFailure,
|
||||
callBinding,
|
||||
SqlCallBinding::newValidationSignatureError
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SqlOperandCountRange getOperandCountRange()
|
||||
{
|
||||
return SqlOperandCountRanges.between(requiredOperands, operandTypes.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getAllowedSignatures(SqlOperator op, String opName)
|
||||
{
|
||||
final List<?> operands = !operandNames.isEmpty() ? operandNames : operandTypes;
|
||||
final StringBuilder ret = new StringBuilder();
|
||||
ret.append("'");
|
||||
ret.append(opName);
|
||||
ret.append("(");
|
||||
for (int i = 0; i < operands.size(); i++) {
|
||||
if (i > 0) {
|
||||
ret.append(", ");
|
||||
}
|
||||
if (i >= requiredOperands) {
|
||||
ret.append("[");
|
||||
}
|
||||
ret.append("<").append(operands.get(i)).append(">");
|
||||
}
|
||||
for (int i = requiredOperands; i < operands.size(); i++) {
|
||||
ret.append("]");
|
||||
}
|
||||
ret.append(")'");
|
||||
return ret.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Consistency getConsistency()
|
||||
{
|
||||
return Consistency.NONE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isOptional(int i)
|
||||
{
|
||||
return i + 1 > requiredOperands;
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean throwOrReturn(
|
||||
final boolean throwOnFailure,
|
||||
final SqlCallBinding callBinding,
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlFunction;
|
||||
import org.apache.calcite.sql.SqlFunctionCategory;
|
||||
import org.apache.calcite.sql.SqlOperator;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.SqlReturnTypeInference;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
|
@ -52,13 +51,10 @@ public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConvers
|
|||
|
||||
private static final SqlFunction SQL_FUNCTION = OperatorConversions
|
||||
.operatorBuilder(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME))
|
||||
.operandTypeChecker(
|
||||
OperandTypes.sequence(
|
||||
"'" + StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME) + "(typeName, base64)'",
|
||||
OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL),
|
||||
OperandTypes.ANY
|
||||
)
|
||||
)
|
||||
.operandNames("typeName", "base64")
|
||||
.operandTypes(SqlTypeFamily.STRING, SqlTypeFamily.ANY)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands(0)
|
||||
.returnTypeInference(ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE)
|
||||
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
|
||||
.build();
|
||||
|
|
|
@ -118,13 +118,10 @@ public class NestedDataOperatorConversions
|
|||
{
|
||||
private static final SqlFunction SQL_FUNCTION = OperatorConversions
|
||||
.operatorBuilder("JSON_KEYS")
|
||||
.operandTypeChecker(
|
||||
OperandTypes.sequence(
|
||||
"'JSON_KEYS(expr, path)'",
|
||||
OperandTypes.ANY,
|
||||
OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL)
|
||||
)
|
||||
)
|
||||
.operandNames("expr", "path")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.STRING)
|
||||
.literalOperands(1)
|
||||
.requiredOperandCount(2)
|
||||
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
|
||||
.returnTypeNullableArrayWithNullableElements(SqlTypeName.VARCHAR)
|
||||
.build();
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.apache.druid.math.expr.Expr;
|
|||
import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider;
|
||||
import org.apache.druid.query.lookup.RegisteredLookupExtractionFn;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.expression.OperatorConversions;
|
||||
import org.apache.druid.sql.calcite.expression.SqlOperatorConversion;
|
||||
|
@ -43,16 +42,10 @@ public class QueryLookupOperatorConversion implements SqlOperatorConversion
|
|||
{
|
||||
private static final SqlFunction SQL_FUNCTION = OperatorConversions
|
||||
.operatorBuilder("LOOKUP")
|
||||
.operandTypeChecker(
|
||||
BasicOperandTypeChecker.builder()
|
||||
.operandTypes(
|
||||
SqlTypeFamily.CHARACTER,
|
||||
SqlTypeFamily.CHARACTER,
|
||||
SqlTypeFamily.CHARACTER
|
||||
)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands(2)
|
||||
.build())
|
||||
.operandNames("expr", "lookupName", "replaceMissingValueWith")
|
||||
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands(2)
|
||||
.returnTypeNullable(SqlTypeName.VARCHAR)
|
||||
.functionCategory(SqlFunctionCategory.STRING)
|
||||
.build();
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.calcite.sql.SqlFunctionCategory;
|
|||
import org.apache.calcite.sql.SqlOperator;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.parser.SqlParserPos;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.calcite.sql2rel.SqlRexContext;
|
||||
|
@ -53,13 +52,10 @@ public class TimeInIntervalConvertletFactory implements DruidConvertletFactory
|
|||
|
||||
private static final SqlOperator OPERATOR = OperatorConversions
|
||||
.operatorBuilder(NAME)
|
||||
.operandTypeChecker(
|
||||
OperandTypes.sequence(
|
||||
"'" + NAME + "(<TIMESTAMP>, <LITERAL ISO8601 INTERVAL>)'",
|
||||
OperandTypes.family(SqlTypeFamily.TIMESTAMP),
|
||||
OperandTypes.and(OperandTypes.family(SqlTypeFamily.CHARACTER), OperandTypes.LITERAL)
|
||||
)
|
||||
)
|
||||
.operandNames("timestamp", "interval")
|
||||
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands(1)
|
||||
.returnTypeNonNull(SqlTypeName.BOOLEAN)
|
||||
.functionCategory(SqlFunctionCategory.TIMEDATE)
|
||||
.build();
|
||||
|
|
|
@ -4454,6 +4454,48 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByCastedRootKeysJsonPath()
|
||||
{
|
||||
cannotVectorize();
|
||||
testQuery(
|
||||
"SELECT "
|
||||
+ "JSON_KEYS(nester, CAST('$.' AS VARCHAR)), "
|
||||
+ "SUM(cnt) "
|
||||
+ "FROM druid.nested GROUP BY 1",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(DATA_SOURCE)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setVirtualColumns(
|
||||
new ExpressionVirtualColumn(
|
||||
"v0",
|
||||
"json_keys(\"nester\",'$.')",
|
||||
ColumnType.STRING_ARRAY,
|
||||
queryFramework().macroTable()
|
||||
)
|
||||
)
|
||||
.setDimensions(
|
||||
dimensions(
|
||||
new DefaultDimensionSpec("v0", "d0", ColumnType.STRING_ARRAY)
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{null, 5L},
|
||||
new Object[]{"[\"array\",\"n\"]", 2L}
|
||||
),
|
||||
RowSignature.builder()
|
||||
.add("EXPR$0", ColumnType.STRING_ARRAY)
|
||||
.add("EXPR$1", ColumnType.LONG)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByAllPaths()
|
||||
{
|
||||
|
|
|
@ -642,10 +642,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
|
||||
testQuery(
|
||||
"SELECT "
|
||||
+ "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), "
|
||||
+ "EARLIEST(cnt), EARLIEST(m1), EARLIEST(dim1, 10), EARLIEST(dim1, CAST(10 AS INTEGER)), "
|
||||
+ "EARLIEST(cnt + 1), EARLIEST(m1 + 1), EARLIEST(dim1 || CAST(cnt AS VARCHAR), 10), "
|
||||
+ "EARLIEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), "
|
||||
+ "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) "
|
||||
+ "EARLIEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), EARLIEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) "
|
||||
+ "FROM druid.numfoo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -677,7 +677,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{1L, 1.0f, "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21"}
|
||||
new Object[]{1L, 1.0f, "", "", 2L, 2.0f, "1", 1L, 3.0f, "2", 2L, 4.0f, "21", "21"}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -752,10 +752,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
|
||||
testQuery(
|
||||
"SELECT "
|
||||
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
|
||||
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), LATEST(dim1, CAST(10 AS INTEGER)), "
|
||||
+ "LATEST(cnt + 1), LATEST(m1 + 1), LATEST(dim1 || CAST(cnt AS VARCHAR), 10), "
|
||||
+ "LATEST_BY(cnt, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1, MILLIS_TO_TIMESTAMP(l1), 10), "
|
||||
+ "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10) "
|
||||
+ "LATEST_BY(cnt + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(m1 + 1, MILLIS_TO_TIMESTAMP(l1)), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), 10), LATEST_BY(dim1 || CAST(cnt AS VARCHAR), MILLIS_TO_TIMESTAMP(l1), CAST(10 AS INTEGER)) "
|
||||
+ "FROM druid.numfoo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -787,7 +787,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{1L, 6.0f, "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11"}
|
||||
new Object[]{1L, 6.0f, "abc", "abc", 2L, 7.0f, "abc1", 1L, 2.0f, "10.1", 2L, 3.0f, "10.11", "10.11"}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -5612,6 +5612,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountStarWithTimeInCastedIntervalFilter()
|
||||
{
|
||||
testQuery(
|
||||
"SELECT COUNT(*) FROM druid.foo "
|
||||
+ "WHERE TIME_IN_INTERVAL(__time, CAST('2000-01-01/P1Y' AS VARCHAR)) "
|
||||
+ "AND TIME_IN_INTERVAL(CURRENT_TIMESTAMP, '2000/3000') -- Optimized away: always true",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(querySegmentSpec(Intervals.of("2000-01-01/2001-01-01")))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(aggregators(new CountAggregatorFactory("a0")))
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{3L}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountStarWithTimeInIntervalFilterLosAngeles()
|
||||
{
|
||||
|
@ -5661,7 +5683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
expected -> {
|
||||
expected.expect(CoreMatchers.instanceOf(DruidException.class));
|
||||
expected.expect(ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
|
||||
"Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [38])")));
|
||||
"Argument to function 'TIME_IN_INTERVAL' must be a literal (line [1], column [63])")));
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -14076,6 +14098,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComplexDecodeAgg()
|
||||
{
|
||||
|
@ -14109,6 +14132,43 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComplexDecodeAggWithCastedTypeName()
|
||||
{
|
||||
msqIncompatible();
|
||||
cannotVectorize();
|
||||
testQuery(
|
||||
"SELECT "
|
||||
+ "APPROX_COUNT_DISTINCT_BUILTIN(COMPLEX_DECODE_BASE64(CAST('hyperUnique' AS VARCHAR),PARSE_JSON(TO_JSON_STRING(unique_dim1)))) "
|
||||
+ "FROM druid.foo",
|
||||
ImmutableList.of(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(querySegmentSpec(Filtration.eternity()))
|
||||
.virtualColumns(
|
||||
expressionVirtualColumn(
|
||||
"v0",
|
||||
"complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))",
|
||||
ColumnType.ofComplex("hyperUnique")
|
||||
)
|
||||
)
|
||||
.aggregators(
|
||||
new HyperUniquesAggregatorFactory(
|
||||
"a0",
|
||||
"v0",
|
||||
false,
|
||||
true
|
||||
)
|
||||
)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{6L}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@NotYetSupported
|
||||
@Test
|
||||
public void testOrderByAlongWithInternalScanQuery()
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
package org.apache.druid.sql.calcite.expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import it.unimi.dsi.fastutil.ints.IntSets;
|
||||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.runtime.CalciteContextException;
|
||||
import org.apache.calcite.runtime.Resources.ExInst;
|
||||
|
@ -38,7 +37,6 @@ import org.apache.calcite.sql.type.SqlTypeName;
|
|||
import org.apache.calcite.sql.validate.SqlValidator;
|
||||
import org.apache.calcite.sql.validate.SqlValidatorScope;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
|
||||
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Rule;
|
||||
|
@ -51,7 +49,6 @@ import org.mockito.Mockito;
|
|||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@RunWith(Enclosed.class)
|
||||
|
@ -65,13 +62,13 @@ public class OperatorConversionsTest
|
|||
@Test
|
||||
public void testGetOperandCountRange()
|
||||
{
|
||||
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
|
||||
Collections.emptyList(),
|
||||
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
|
||||
2,
|
||||
IntSets.EMPTY_SET,
|
||||
null
|
||||
);
|
||||
SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames()
|
||||
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands()
|
||||
.build();
|
||||
SqlOperandCountRange countRange = typeChecker.getOperandCountRange();
|
||||
Assert.assertEquals(2, countRange.getMin());
|
||||
Assert.assertEquals(3, countRange.getMax());
|
||||
|
@ -80,13 +77,13 @@ public class OperatorConversionsTest
|
|||
@Test
|
||||
public void testIsOptional()
|
||||
{
|
||||
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
|
||||
Collections.emptyList(),
|
||||
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
|
||||
2,
|
||||
IntSets.EMPTY_SET,
|
||||
null
|
||||
);
|
||||
SqlOperandTypeChecker typeChecker = DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames()
|
||||
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
|
||||
.requiredOperandCount(2)
|
||||
.literalOperands()
|
||||
.build();
|
||||
Assert.assertFalse(typeChecker.isOptional(0));
|
||||
Assert.assertFalse(typeChecker.isOptional(1));
|
||||
Assert.assertTrue(typeChecker.isOptional(2));
|
||||
|
@ -121,7 +118,7 @@ public class OperatorConversionsTest
|
|||
{
|
||||
SqlFunction function = OperatorConversions
|
||||
.operatorBuilder("testRequiredOperandsOnly")
|
||||
.operandTypeChecker(BasicOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build())
|
||||
.operandTypeChecker(DefaultOperandTypeChecker.builder().operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE).requiredOperandCount(1).build())
|
||||
.returnTypeNonNull(SqlTypeName.CHAR)
|
||||
.build();
|
||||
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
|
||||
|
|
Loading…
Reference in New Issue