SQL OperatorConversions: Introduce.aggregatorBuilder, allow CAST-as-literal. (#14249)

* SQL OperatorConversions: Introduce.aggregatorBuilder, allow CAST-as-literal.

Four main changes:

1) Provide aggregatorBuilder, a more consistent way of defining the
   SqlAggFunction we need for all of our SQL aggregators. The mechanism
   is analogous to the one we already use for SQL functions
   (OperatorConversions.operatorBuilder).

2) Allow CASTs of constants to be considered as "literalOperands". This
   fixes an issue where various of our operators are defined with
   OperandTypes.LITERAL as part of their checkers, which doesn't allow
   casts. However, in these cases we generally _do_ want to allow casts.
   The important piece is that the value must be reducible to a constant,
   not that the SQL text is literally a literal.

3) Update DataSketches SQL aggregators to use the new aggregatorBuilder
   functionality. The main user-visible effect here is [2]: the aggregators
   would now accept, for example, "CAST(0.99 AS DOUBLE)" as a literal
   argument. Other aggregators could be updated in a future patch.

4) Rename "requiredOperands" to "requiredOperandCount", because the
   old name was confusing. (It rhymes with "literalOperands" but the
   arguments mean different things.)

* Adjust method calls.
This commit is contained in:
Gian Merlino 2023-06-23 16:25:04 -07:00 committed by GitHub
parent 1d6c9657ec
commit 3d19b748fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 324 additions and 287 deletions

View File

@ -21,24 +21,30 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections; import java.util.Collections;
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{ {
public static final String NAME = "APPROX_COUNT_DISTINCT_DS_HLL"; public static final String NAME = "APPROX_COUNT_DISTINCT_DS_HLL";
private static final SqlAggFunction FUNCTION_INSTANCE =
private static final SqlAggFunction FUNCTION_INSTANCE = new HllSketchApproxCountDistinctSqlAggFunction(); OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "lgK", "tgtHllType")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1, 2)
.returnTypeNonNull(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
public HllSketchApproxCountDistinctSqlAggregator() public HllSketchApproxCountDistinctSqlAggregator()
{ {
@ -66,30 +72,4 @@ public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlA
) : null ) : null
); );
} }
private static class HllSketchApproxCountDistinctSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, lgK, tgtHllType)'";
HllSketchApproxCountDistinctSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.BIGINT),
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
)
),
SqlFunctionCategory.NUMERIC,
false,
false
);
}
}
} }

View File

@ -47,7 +47,7 @@ public class HllSketchEstimateOperatorConversion implements SqlOperatorConversio
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME)) .operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME))
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.BOOLEAN) .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.BOOLEAN)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeInference(ReturnTypes.DOUBLE) .returnTypeInference(ReturnTypes.DOUBLE)
.build(); .build();

View File

@ -46,7 +46,7 @@ public class HllSketchEstimateWithErrorBoundsOperatorConversion implements SqlOp
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME)) .operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME))
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.OTHER) .returnTypeNonNull(SqlTypeName.OTHER)
.build(); .build();

View File

@ -21,22 +21,29 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections; import java.util.Collections;
public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{ {
private static final SqlAggFunction FUNCTION_INSTANCE = new HllSketchSqlAggFunction();
private static final String NAME = "DS_HLL"; private static final String NAME = "DS_HLL";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "lgK", "tgtHllType")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1, 2)
.returnTypeNonNull(SqlTypeName.OTHER)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
public HllSketchObjectSqlAggregator() public HllSketchObjectSqlAggregator()
{ {
@ -61,30 +68,4 @@ public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator imp
null null
); );
} }
private static class HllSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, lgK, tgtHllType)'";
HllSketchSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
);
}
}
} }

View File

@ -28,8 +28,6 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
@ -45,19 +43,27 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
{ {
public static final String CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH = "approxQuantileDsMaxStreamLength"; public static final String CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH = "approxQuantileDsMaxStreamLength";
private static final SqlAggFunction FUNCTION_INSTANCE = new DoublesSketchApproxQuantileSqlAggFunction();
private static final String NAME = "APPROX_QUANTILE_DS"; private static final String NAME = "APPROX_QUANTILE_DS";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "probability", "k")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
.returnTypeNonNull(SqlTypeName.DOUBLE)
.requiredOperandCount(2)
.literalOperands(1, 2)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
@Override @Override
public SqlAggFunction calciteFunction() public SqlAggFunction calciteFunction()
@ -212,34 +218,4 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH
); );
} }
private static class DoublesSketchApproxQuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'";
private static final String SIGNATURE2 = "'" + NAME + "(column, probability, k)'";
DoublesSketchApproxQuantileSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
)
),
SqlFunctionCategory.NUMERIC,
false,
false
);
}
}
} }

View File

@ -28,8 +28,6 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
@ -43,17 +41,25 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator public class DoublesSketchObjectSqlAggregator implements SqlAggregator
{ {
private static final SqlAggFunction FUNCTION_INSTANCE = new DoublesSketchSqlAggFunction();
private static final String NAME = "DS_QUANTILES_SKETCH"; private static final String NAME = "DS_QUANTILES_SKETCH";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "k")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
.returnTypeNonNull(SqlTypeName.OTHER)
.requiredOperandCount(1)
.literalOperands(1)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
@Override @Override
public SqlAggFunction calciteFunction() public SqlAggFunction calciteFunction()
@ -139,30 +145,4 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
null null
); );
} }
private static class DoublesSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE2 = "'" + NAME + "(column, k)'";
DoublesSketchSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
);
}
}
} }

View File

@ -21,24 +21,30 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections; import java.util.Collections;
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{ {
public static final String NAME = "APPROX_COUNT_DISTINCT_DS_THETA"; public static final String NAME = "APPROX_COUNT_DISTINCT_DS_THETA";
private static final SqlAggFunction FUNCTION_INSTANCE =
private static final SqlAggFunction FUNCTION_INSTANCE = new ThetaSketchSqlAggFunction(); OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "size")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1)
.returnTypeNonNull(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
public ThetaSketchApproxCountDistinctSqlAggregator() public ThetaSketchApproxCountDistinctSqlAggregator()
{ {
@ -66,30 +72,4 @@ public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBase
) : null ) : null
); );
} }
private static class ThetaSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, size)'";
ThetaSketchSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.BIGINT),
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
SqlFunctionCategory.NUMERIC,
false,
false
);
}
}
} }

View File

@ -21,20 +21,28 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections; import java.util.Collections;
public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{ {
private static final SqlAggFunction FUNCTION_INSTANCE = new ThetaSketchObjectSqlAggFunction();
private static final String NAME = "DS_THETA"; private static final String NAME = "DS_THETA";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "size")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1)
.returnTypeInference(ThetaSketchSqlOperators.RETURN_TYPE_INFERENCE)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
public ThetaSketchObjectSqlAggregator() public ThetaSketchObjectSqlAggregator()
{ {
@ -59,30 +67,4 @@ public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator
null null
); );
} }
private static class ThetaSketchObjectSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, size)'";
ThetaSketchObjectSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ThetaSketchSqlOperators.RETURN_TYPE_INFERENCE,
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
);
}
}
} }

View File

@ -177,18 +177,19 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
+ " APPROX_COUNT_DISTINCT(SUBSTRING(dim2, 1, 1)),\n" // on extractionFn, using generic A.C.D. + " APPROX_COUNT_DISTINCT(SUBSTRING(dim2, 1, 1)),\n" // on extractionFn, using generic A.C.D.
+ " COUNT(DISTINCT SUBSTRING(dim2, 1, 1) || 'x'),\n" // on expression, using COUNT DISTINCT + " COUNT(DISTINCT SUBSTRING(dim2, 1, 1) || 'x'),\n" // on expression, using COUNT DISTINCT
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1, 21, 'HLL_8'),\n" // on native HllSketch column + " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1, 21, 'HLL_8'),\n" // on native HllSketch column
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1)\n" // on native HllSketch column + " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1),\n" // on native HllSketch column
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1, CAST(21 AS BIGINT))\n" // also native column
+ "FROM druid.foo"; + "FROM druid.foo";
final List<Object[]> expectedResults; final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) { if (NullHandling.replaceWithDefault()) {
expectedResults = ImmutableList.of( expectedResults = ImmutableList.of(
new Object[]{6L, 2L, 2L, 1L, 2L, 5L, 5L} new Object[]{6L, 2L, 2L, 1L, 2L, 5L, 5L, 5L}
); );
} else { } else {
expectedResults = ImmutableList.of( expectedResults = ImmutableList.of(
new Object[]{6L, 2L, 2L, 1L, 1L, 5L, 5L} new Object[]{6L, 2L, 2L, 1L, 1L, 5L, 5L, 5L}
); );
} }
@ -252,7 +253,8 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
ROUND ROUND
), ),
new HllSketchMergeAggregatorFactory("a5", "hllsketch_dim1", 21, "HLL_8", null, ROUND), new HllSketchMergeAggregatorFactory("a5", "hllsketch_dim1", 21, "HLL_8", null, ROUND),
new HllSketchMergeAggregatorFactory("a6", "hllsketch_dim1", null, null, null, ROUND) new HllSketchMergeAggregatorFactory("a6", "hllsketch_dim1", null, null, null, ROUND),
new HllSketchMergeAggregatorFactory("a7", "hllsketch_dim1", 21, "HLL_4", null, ROUND)
) )
) )
.context(QUERY_CONTEXT_DEFAULT) .context(QUERY_CONTEXT_DEFAULT)

View File

@ -207,7 +207,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.01),\n" + "APPROX_QUANTILE_DS(qsketch_m1, 0.01),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.5, 64),\n" + "APPROX_QUANTILE_DS(qsketch_m1, 0.5, 64),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.98, 256),\n" + "APPROX_QUANTILE_DS(qsketch_m1, 0.98, 256),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.99),\n" + "APPROX_QUANTILE_DS(qsketch_m1, CAST(0.99 AS DOUBLE)),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n" + "APPROX_QUANTILE_DS(qsketch_m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n" + "APPROX_QUANTILE_DS(qsketch_m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.999) FILTER(WHERE dim1 = 'abc')\n" + "APPROX_QUANTILE_DS(qsketch_m1, 0.999) FILTER(WHERE dim1 = 'abc')\n"

View File

@ -42,7 +42,7 @@ public class SleepOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("SLEEP") .operatorBuilder("SLEEP")
.operandTypes(SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.VARCHAR) // always null .returnTypeNullable(SqlTypeName.VARCHAR) // always null
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -42,7 +42,7 @@ public class SleepOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("SLEEP") .operatorBuilder("SLEEP")
.operandTypes(SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.VARCHAR) // always null .returnTypeNullable(SqlTypeName.VARCHAR) // always null
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -32,6 +32,7 @@ import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.CalciteException; import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
@ -49,6 +50,7 @@ import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Static; import org.apache.calcite.util.Static;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
@ -57,12 +59,14 @@ import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator; import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.IntStream; import java.util.stream.IntStream;
@ -316,23 +320,33 @@ public class OperatorConversions
* Returns a builder that helps {@link SqlOperatorConversion} implementations create the {@link SqlFunction} * Returns a builder that helps {@link SqlOperatorConversion} implementations create the {@link SqlFunction}
* objects they need to return from {@link SqlOperatorConversion#calciteOperator()}. * objects they need to return from {@link SqlOperatorConversion#calciteOperator()}.
*/ */
public static OperatorBuilder operatorBuilder(final String name) public static OperatorBuilder<SqlFunction> operatorBuilder(final String name)
{ {
return new OperatorBuilder(name); return new OperatorBuilder<>(name);
} }
public static class OperatorBuilder /**
* Returns a builder that helps {@link SqlAggregator} implementations create the {@link SqlAggFunction} objects
* they need to return from {@link SqlAggregator#calciteFunction()}.
*/
public static OperatorBuilder<SqlAggFunction> aggregatorBuilder(final String name)
{ {
private final String name; return new AggregatorBuilder(name);
private SqlKind kind = SqlKind.OTHER_FUNCTION; }
private SqlReturnTypeInference returnTypeInference;
private SqlFunctionCategory functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION; public static class OperatorBuilder<T extends SqlFunction>
{
protected final String name;
protected SqlKind kind = SqlKind.OTHER_FUNCTION;
protected SqlReturnTypeInference returnTypeInference;
protected SqlFunctionCategory functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION;
// For operand type checking // For operand type checking
private SqlOperandTypeChecker operandTypeChecker; private SqlOperandTypeChecker operandTypeChecker;
private List<String> operandNames = Collections.emptyList();
private List<SqlTypeFamily> operandTypes; private List<SqlTypeFamily> operandTypes;
private Integer requiredOperands = null; private Integer requiredOperandCount;
private int[] literalOperands = null; private int[] literalOperands;
private SqlOperandTypeInference operandTypeInference; private SqlOperandTypeInference operandTypeInference;
private OperatorBuilder(final String name) private OperatorBuilder(final String name)
@ -348,7 +362,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} * {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName) public OperatorBuilder<T> returnTypeNonNull(final SqlTypeName typeName)
{ {
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -365,7 +379,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} * {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder returnTypeNullable(final SqlTypeName typeName) public OperatorBuilder<T> returnTypeNullable(final SqlTypeName typeName)
{ {
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -382,7 +396,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} * {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName) public OperatorBuilder<T> returnTypeCascadeNullable(final SqlTypeName typeName)
{ {
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE); this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE);
@ -396,7 +410,7 @@ public class OperatorConversions
* {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be * {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be
* used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder returnTypeArrayWithNullableElements(final SqlTypeName elementTypeName) public OperatorBuilder<T> returnTypeArrayWithNullableElements(final SqlTypeName elementTypeName)
{ {
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -413,7 +427,7 @@ public class OperatorConversions
* {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be * {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be
* used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder returnTypeNullableArrayWithNullableElements(final SqlTypeName elementTypeName) public OperatorBuilder<T> returnTypeNullableArrayWithNullableElements(final SqlTypeName elementTypeName)
{ {
this.returnTypeInference = ReturnTypes.cascade( this.returnTypeInference = ReturnTypes.cascade(
opBinding -> Calcites.createSqlArrayTypeWithNullability( opBinding -> Calcites.createSqlArrayTypeWithNullability(
@ -434,7 +448,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} * {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference) public OperatorBuilder<T> returnTypeInference(final SqlReturnTypeInference returnTypeInference)
{ {
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -447,7 +461,7 @@ public class OperatorConversions
* *
* The default, if not provided, is {@link SqlFunctionCategory#USER_DEFINED_FUNCTION}. * The default, if not provided, is {@link SqlFunctionCategory#USER_DEFINED_FUNCTION}.
*/ */
public OperatorBuilder functionCategory(final SqlFunctionCategory functionCategory) public OperatorBuilder<T> functionCategory(final SqlFunctionCategory functionCategory)
{ {
this.functionCategory = functionCategory; this.functionCategory = functionCategory;
return this; return this;
@ -459,21 +473,32 @@ public class OperatorConversions
* One of {@link #operandTypes(SqlTypeFamily...)} or {@link #operandTypeChecker(SqlOperandTypeChecker)} must be used * One of {@link #operandTypes(SqlTypeFamily...)} or {@link #operandTypeChecker(SqlOperandTypeChecker)} must be used
* before calling {@link #build()}. These methods cannot be mixed; you must call exactly one. * before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/ */
public OperatorBuilder operandTypeChecker(final SqlOperandTypeChecker operandTypeChecker) public OperatorBuilder<T> operandTypeChecker(final SqlOperandTypeChecker operandTypeChecker)
{ {
this.operandTypeChecker = operandTypeChecker; this.operandTypeChecker = operandTypeChecker;
return this; return this;
} }
/**
* Signifies that a function accepts operands with the provided names. This is used to implement
* {@link SqlOperandTypeChecker#getAllowedSignatures(SqlOperator, String)}. If not provided, the
* {@link #operandTypes} are used instead.
*/
public OperatorBuilder<T> operandNames(final String... operandNames)
{
this.operandNames = Arrays.asList(operandNames);
return this;
}
/** /**
* Signifies that a function accepts operands of type family given by {@param operandTypes}. * Signifies that a function accepts operands of type family given by {@param operandTypes}.
* *
* May be used in conjunction with {@link #requiredOperands(int)} and {@link #literalOperands(int...)} in order * May be used in conjunction with {@link #requiredOperandCount(int)} and {@link #literalOperands(int...)} in order
* to further refine operand checking logic. * to further refine operand checking logic.
* *
* For deeper control, use {@link #operandTypeChecker(SqlOperandTypeChecker)} instead. * For deeper control, use {@link #operandTypeChecker(SqlOperandTypeChecker)} instead.
*/ */
public OperatorBuilder operandTypes(final SqlTypeFamily... operandTypes) public OperatorBuilder<T> operandTypes(final SqlTypeFamily... operandTypes)
{ {
this.operandTypes = Arrays.asList(operandTypes); this.operandTypes = Arrays.asList(operandTypes);
return this; return this;
@ -489,67 +514,97 @@ public class OperatorConversions
* Must be used in conjunction with {@link #operandTypes(SqlTypeFamily...)}; this method is not compatible with * Must be used in conjunction with {@link #operandTypes(SqlTypeFamily...)}; this method is not compatible with
* {@link #operandTypeChecker(SqlOperandTypeChecker)}. * {@link #operandTypeChecker(SqlOperandTypeChecker)}.
*/ */
public OperatorBuilder requiredOperands(final int requiredOperands) public OperatorBuilder<T> requiredOperandCount(final int requiredOperandCount)
{ {
this.requiredOperands = requiredOperands; this.requiredOperandCount = requiredOperandCount;
return this; return this;
} }
/**
* Alias for {@link #requiredOperandCount(int)}. Deprecated because it means "operand count" rather than
* "specific operands", and therefore the name can cause confusion with {@link #literalOperands(int...)}. The latter
* really does mean "specific operands".
*/
@Deprecated
@SuppressWarnings("unused") // For compatibility with existing extensions
public OperatorBuilder<T> requiredOperands(final int requiredOperands)
{
return requiredOperandCount(requiredOperands);
}
/** /**
* Signifies that the operands at positions given by {@code literalOperands} must be literals. * Signifies that the operands at positions given by {@code literalOperands} must be literals.
* *
* Must be used in conjunction with {@link #operandTypes(SqlTypeFamily...)}; this method is not compatible with * Must be used in conjunction with {@link #operandTypes(SqlTypeFamily...)}; this method is not compatible with
* {@link #operandTypeChecker(SqlOperandTypeChecker)}. * {@link #operandTypeChecker(SqlOperandTypeChecker)}.
*/ */
public OperatorBuilder literalOperands(final int... literalOperands) public OperatorBuilder<T> literalOperands(final int... literalOperands)
{ {
this.literalOperands = literalOperands; this.literalOperands = literalOperands;
return this; return this;
} }
public OperatorBuilder operandTypeInference(SqlOperandTypeInference operandTypeInference) public OperatorBuilder<T> operandTypeInference(SqlOperandTypeInference operandTypeInference)
{ {
this.operandTypeInference = operandTypeInference; this.operandTypeInference = operandTypeInference;
return this; return this;
} }
public OperatorBuilder sqlKind(SqlKind kind)
{
this.kind = kind;
return this;
}
/** /**
* Creates a {@link SqlFunction} from this builder. * Creates a {@link SqlFunction} from this builder.
*/ */
public SqlFunction build() @SuppressWarnings("unchecked")
public T build()
{
final IntSet nullableOperands = buildNullableOperands();
return (T) new SqlFunction(
name,
kind,
Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"),
buildOperandTypeInference(nullableOperands),
buildOperandTypeChecker(nullableOperands),
functionCategory
);
}
protected IntSet buildNullableOperands()
{ {
// Create "nullableOperands" set including all optional arguments. // Create "nullableOperands" set including all optional arguments.
final IntSet nullableOperands = new IntArraySet(); final IntSet nullableOperands = new IntArraySet();
if (requiredOperands != null) { if (requiredOperandCount != null) {
IntStream.range(requiredOperands, operandTypes.size()).forEach(nullableOperands::add); IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add);
} }
return nullableOperands;
}
final SqlOperandTypeChecker theOperandTypeChecker; protected SqlOperandTypeChecker buildOperandTypeChecker(final IntSet nullableOperands)
{
if (operandTypeChecker == null) { if (operandTypeChecker == null) {
theOperandTypeChecker = new DefaultOperandTypeChecker( return new DefaultOperandTypeChecker(
operandNames,
operandTypes, operandTypes,
requiredOperands == null ? operandTypes.size() : requiredOperands, requiredOperandCount == null ? operandTypes.size() : requiredOperandCount,
nullableOperands, nullableOperands,
literalOperands literalOperands
); );
} else if (operandTypes == null && requiredOperands == null && literalOperands == null) { } else if (operandNames.isEmpty()
theOperandTypeChecker = operandTypeChecker; && operandTypes == null
&& requiredOperandCount == null
&& literalOperands == null) {
return operandTypeChecker;
} else { } else {
throw new ISE( throw new ISE(
"Cannot have both 'operandTypeChecker' and 'operandTypes' / 'requiredOperands' / 'literalOperands'" "Cannot have both 'operandTypeChecker' and "
+ "'operandNames' / 'operandTypes' / 'requiredOperands' / 'literalOperands'"
); );
} }
}
protected SqlOperandTypeInference buildOperandTypeInference(final IntSet nullableOperands)
{
if (operandTypeInference == null) { if (operandTypeInference == null) {
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands); SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
operandTypeInference = (callBinding, returnType, types) -> { return (callBinding, returnType, types) -> {
for (int i = 0; i < types.length; i++) { for (int i = 0; i < types.length; i++) {
// calcite sql validate tries to do bad things to dynamic parameters if the type is inferred to be a string // calcite sql validate tries to do bad things to dynamic parameters if the type is inferred to be a string
if (callBinding.operand(i).isA(ImmutableSet.of(SqlKind.DYNAMIC_PARAM))) { if (callBinding.operand(i).isA(ImmutableSet.of(SqlKind.DYNAMIC_PARAM))) {
@ -562,15 +617,49 @@ public class OperatorConversions
} }
} }
}; };
} else {
return operandTypeInference;
} }
return new SqlFunction( }
name, }
kind,
Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"), public static class AggregatorBuilder extends OperatorBuilder<SqlAggFunction>
operandTypeInference, {
theOperandTypeChecker, public AggregatorBuilder(String name)
functionCategory {
); super(name);
}
/**
* Create a {@link SqlAggFunction} from this builder.
*/
@Override
public SqlAggFunction build()
{
final IntSet nullableOperands = buildNullableOperands();
final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands);
final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(nullableOperands);
class DruidSqlAggFunction extends SqlAggFunction
{
public DruidSqlAggFunction()
{
super(
name,
null,
AggregatorBuilder.this.kind,
returnTypeInference,
operandTypeInference,
operandTypeChecker,
functionCategory,
false,
false,
Optionality.FORBIDDEN
);
}
}
return new DruidSqlAggFunction();
} }
} }
@ -655,6 +744,11 @@ public class OperatorConversions
@VisibleForTesting @VisibleForTesting
static class DefaultOperandTypeChecker implements SqlOperandTypeChecker static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
{ {
/**
* Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the
* {@link #operandTypes} are used instead.
*/
private final List<String> operandNames;
private final List<SqlTypeFamily> operandTypes; private final List<SqlTypeFamily> operandTypes;
private final int requiredOperands; private final int requiredOperands;
private final IntSet nullableOperands; private final IntSet nullableOperands;
@ -662,6 +756,7 @@ public class OperatorConversions
@VisibleForTesting @VisibleForTesting
DefaultOperandTypeChecker( DefaultOperandTypeChecker(
final List<String> operandNames,
final List<SqlTypeFamily> operandTypes, final List<SqlTypeFamily> operandTypes,
final int requiredOperands, final int requiredOperands,
final IntSet nullableOperands, final IntSet nullableOperands,
@ -669,10 +764,15 @@ public class OperatorConversions
) )
{ {
Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0); Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0);
this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames");
this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes"); this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes");
this.requiredOperands = requiredOperands; this.requiredOperands = requiredOperands;
this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands"); this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands");
if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) {
throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size());
}
if (literalOperands == null) { if (literalOperands == null) {
this.literalOperands = IntSets.EMPTY_SET; this.literalOperands = IntSets.EMPTY_SET;
} else { } else {
@ -688,8 +788,8 @@ public class OperatorConversions
final SqlNode operand = callBinding.operands().get(i); final SqlNode operand = callBinding.operands().get(i);
if (literalOperands.contains(i)) { if (literalOperands.contains(i)) {
// Verify that 'operand' is a literal. // Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later.
if (!SqlUtil.isLiteral(operand)) { if (!SqlUtil.isLiteral(operand, true)) {
return throwOrReturn( return throwOrReturn(
throwOnFailure, throwOnFailure,
callBinding, callBinding,
@ -739,7 +839,25 @@ public class OperatorConversions
@Override @Override
public String getAllowedSignatures(SqlOperator op, String opName) public String getAllowedSignatures(SqlOperator op, String opName)
{ {
return SqlUtil.getAliasedSignature(op, opName, operandTypes); final List<?> operands = !operandNames.isEmpty() ? operandNames : operandTypes;
final StringBuilder ret = new StringBuilder();
ret.append("'");
ret.append(opName);
ret.append("(");
for (int i = 0; i < operands.size(); i++) {
if (i > 0) {
ret.append(", ");
}
if (i >= requiredOperands) {
ret.append("[");
}
ret.append("<").append(operands.get(i)).append(">");
}
for (int i = requiredOperands; i < operands.size(); i++) {
ret.append("]");
}
ret.append(")'");
return ret.toString();
} }
@Override @Override
@ -772,7 +890,7 @@ public class OperatorConversions
{ {
return new DirectOperatorConversion( return new DirectOperatorConversion(
operatorBuilder(sqlOperator) operatorBuilder(sqlOperator)
.requiredOperands(1) .requiredOperandCount(1)
.operandTypes(SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC)
.returnTypeNullable(SqlTypeName.BIGINT) .returnTypeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC) .functionCategory(SqlFunctionCategory.NUMERIC)
@ -785,7 +903,7 @@ public class OperatorConversions
{ {
return new DirectOperatorConversion( return new DirectOperatorConversion(
operatorBuilder(sqlOperator) operatorBuilder(sqlOperator)
.requiredOperands(2) .requiredOperandCount(2)
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
.returnTypeNullable(SqlTypeName.BIGINT) .returnTypeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC) .functionCategory(SqlFunctionCategory.NUMERIC)
@ -798,7 +916,7 @@ public class OperatorConversions
{ {
return new DirectOperatorConversion( return new DirectOperatorConversion(
operatorBuilder(StringUtils.toUpperCase(sqlOperator)) operatorBuilder(StringUtils.toUpperCase(sqlOperator))
.requiredOperands(1) .requiredOperandCount(1)
.operandTypes(SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC)
.returnTypeNullable(SqlTypeName.DOUBLE) .returnTypeNullable(SqlTypeName.DOUBLE)
.functionCategory(SqlFunctionCategory.NUMERIC) .functionCategory(SqlFunctionCategory.NUMERIC)

View File

@ -40,7 +40,7 @@ public class BTrimOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1) .requiredOperandCount(1)
.build(); .build();
@Override @Override

View File

@ -80,7 +80,7 @@ public class ContainsOperatorConversion extends DirectOperatorConversion
return OperatorConversions return OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(functionName)) .operatorBuilder(StringUtils.toUpperCase(functionName))
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2) .requiredOperandCount(2)
.literalOperands(1) .literalOperands(1)
.returnTypeNonNull(SqlTypeName.BOOLEAN) .returnTypeNonNull(SqlTypeName.BOOLEAN)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)

View File

@ -67,7 +67,7 @@ public class DateTruncOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("DATE_TRUNC") .operatorBuilder("DATE_TRUNC")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -40,7 +40,7 @@ public class LPadOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(2) .requiredOperandCount(2)
.build(); .build();
@Override @Override

View File

@ -40,7 +40,7 @@ public class LTrimOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1) .requiredOperandCount(1)
.build(); .build();
@Override @Override

View File

@ -40,7 +40,7 @@ public class ParseLongOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.returnTypeCascadeNullable(SqlTypeName.BIGINT) .returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1) .requiredOperandCount(1)
.build(); .build();
@Override @Override

View File

@ -40,7 +40,7 @@ public class RPadOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(2) .requiredOperandCount(2)
.build(); .build();
@Override @Override

View File

@ -40,7 +40,7 @@ public class RTrimOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1) .requiredOperandCount(1)
.build(); .build();
@Override @Override

View File

@ -38,7 +38,7 @@ public class RegexpExtractOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("REGEXP_EXTRACT") .operatorBuilder("REGEXP_EXTRACT")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.requiredOperands(2) .requiredOperandCount(2)
.literalOperands(1, 2) .literalOperands(1, 2)
.returnTypeNullable(SqlTypeName.VARCHAR) .returnTypeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)

View File

@ -44,7 +44,7 @@ public class RegexpLikeOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("REGEXP_LIKE") .operatorBuilder("REGEXP_LIKE")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2) .requiredOperandCount(2)
.literalOperands(1) .literalOperands(1)
.returnTypeNonNull(SqlTypeName.BOOLEAN) .returnTypeNonNull(SqlTypeName.BOOLEAN)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)

View File

@ -35,7 +35,7 @@ public class RoundOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("ROUND") .operatorBuilder("ROUND")
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeInference(ReturnTypes.ARG0) .returnTypeInference(ReturnTypes.ARG0)
.functionCategory(SqlFunctionCategory.NUMERIC) .functionCategory(SqlFunctionCategory.NUMERIC)
.build(); .build();

View File

@ -43,7 +43,7 @@ public class SubstringOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.returnTypeInference(ReturnTypes.ARG0) .returnTypeInference(ReturnTypes.ARG0)
.requiredOperands(2) .requiredOperandCount(2)
.build(); .build();
@Override @Override

View File

@ -31,7 +31,7 @@ public class TextcatOperatorConversion extends DirectOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("textcat") .operatorBuilder("textcat")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING) .functionCategory(SqlFunctionCategory.STRING)
.build(); .build();

View File

@ -41,7 +41,7 @@ public class TimeCeilOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_CEIL") .operatorBuilder("TIME_CEIL")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -43,7 +43,7 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_EXTRACT") .operatorBuilder("TIME_EXTRACT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.BIGINT) .returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -56,7 +56,7 @@ public class TimeFloorOperatorConversion implements SqlOperatorConversion
public static final SqlFunction SQL_FUNCTION = OperatorConversions public static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(SQL_FUNCTION_NAME) .operatorBuilder(SQL_FUNCTION_NAME)
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -46,7 +46,7 @@ public class TimeFormatOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_FORMAT") .operatorBuilder("TIME_FORMAT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR) .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -43,7 +43,7 @@ public class TimeParseOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_PARSE") .operatorBuilder("TIME_PARSE")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.TIMESTAMP) .returnTypeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -43,7 +43,7 @@ public class TimeShiftOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_SHIFT") .operatorBuilder("TIME_SHIFT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.requiredOperands(3) .requiredOperandCount(3)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE) .functionCategory(SqlFunctionCategory.TIMEDATE)
.build(); .build();

View File

@ -37,7 +37,7 @@ public class TruncateOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TRUNCATE") .operatorBuilder("TRUNCATE")
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeInference(ReturnTypes.ARG0) .returnTypeInference(ReturnTypes.ARG0)
.functionCategory(SqlFunctionCategory.NUMERIC) .functionCategory(SqlFunctionCategory.NUMERIC)
.build(); .build();

View File

@ -51,6 +51,7 @@ import org.mockito.stubbing.Answer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
@RunWith(Enclosed.class) @RunWith(Enclosed.class)
@ -65,6 +66,7 @@ public class OperatorConversionsTest
public void testGetOperandCountRange() public void testGetOperandCountRange()
{ {
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker( SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
Collections.emptyList(),
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2, 2,
IntSets.EMPTY_SET, IntSets.EMPTY_SET,
@ -79,6 +81,7 @@ public class OperatorConversionsTest
public void testIsOptional() public void testIsOptional()
{ {
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker( SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
Collections.emptyList(),
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER), ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2, 2,
IntSets.EMPTY_SET, IntSets.EMPTY_SET,
@ -95,7 +98,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testAllowFullOperands") .operatorBuilder("testAllowFullOperands")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE) .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -119,7 +122,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testRequiredOperandsOnly") .operatorBuilder("testRequiredOperandsOnly")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE) .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -140,7 +143,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testLiteralOperandCheckLiteral") .operatorBuilder("testLiteralOperandCheckLiteral")
.operandTypes(SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.INTEGER)
.requiredOperands(1) .requiredOperandCount(1)
.literalOperands(0) .literalOperands(0)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
@ -162,7 +165,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testLiteralOperandCheckLiteralThrow") .operatorBuilder("testLiteralOperandCheckLiteralThrow")
.operandTypes(SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.INTEGER)
.requiredOperands(1) .requiredOperandCount(1)
.literalOperands(0) .literalOperands(0)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
@ -184,7 +187,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testAnyTypeOperand") .operatorBuilder("testAnyTypeOperand")
.operandTypes(SqlTypeFamily.ANY) .operandTypes(SqlTypeFamily.ANY)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -205,7 +208,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testCastableFromDatetimeFamilyToTimestamp") .operatorBuilder("testCastableFromDatetimeFamilyToTimestamp")
.operandTypes(SqlTypeFamily.DATETIME) .operandTypes(SqlTypeFamily.DATETIME)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -235,7 +238,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableOperand") .operatorBuilder("testNullForNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -259,7 +262,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullLiteralForNullableOperand") .operatorBuilder("testNullLiteralForNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -283,7 +286,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull") .operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -304,7 +307,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableCascade") .operatorBuilder("testNullForNullableCascade")
.operandTypes(SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeCascadeNullable(SqlTypeName.CHAR) .returnTypeCascadeNullable(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -325,7 +328,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull") .operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER) .operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.CHAR) .returnTypeNullable(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -346,7 +349,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNonNullableOperand") .operatorBuilder("testNullForNonNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -372,7 +375,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNullLiteralForNonNullableOperand") .operatorBuilder("testNullLiteralForNonNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1) .requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -398,7 +401,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions SqlFunction function = OperatorConversions
.operatorBuilder("testNonCastableType") .operatorBuilder("testNonCastableType")
.operandTypes(SqlTypeFamily.CURSOR, SqlTypeFamily.INTERVAL_DAY_TIME) .operandTypes(SqlTypeFamily.CURSOR, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeNonNull(SqlTypeName.CHAR) .returnTypeNonNull(SqlTypeName.CHAR)
.build(); .build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -418,6 +421,41 @@ public class OperatorConversionsTest
); );
} }
@Test
public void testSignatureWithNames()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testSignatureWithNames")
.operandNames("x", "y", "z")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE, SqlTypeFamily.ANY)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertEquals(
"'testSignatureWithNames(<x>, [<y>, [<z>]])'",
typeChecker.getAllowedSignatures(function, function.getName())
);
}
@Test
public void testSignatureWithoutNames()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testSignatureWithoutNames")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE, SqlTypeFamily.ANY)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertEquals(
"'testSignatureWithoutNames(<INTEGER>, [<DATE>, [<ANY>]])'",
typeChecker.getAllowedSignatures(function, function.getName())
);
}
private static SqlCallBinding mockCallBinding( private static SqlCallBinding mockCallBinding(
SqlFunction function, SqlFunction function,
List<OperandSpec> actualOperands List<OperandSpec> actualOperands

View File

@ -68,7 +68,7 @@ public class DruidOperatorTableTest
final SqlOperator operator1 = OperatorConversions final SqlOperator operator1 = OperatorConversions
.operatorBuilder("FOO") .operatorBuilder("FOO")
.operandTypes(SqlTypeFamily.ANY) .operandTypes(SqlTypeFamily.ANY)
.requiredOperands(0) .requiredOperandCount(0)
.returnTypeInference( .returnTypeInference(
opBinding -> RowSignatures.makeComplexType( opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(), opBinding.getTypeFactory(),

View File

@ -69,7 +69,7 @@ public class DruidRexExecutorTest extends InitializedNullHandlingTest
private static final SqlOperator OPERATOR = OperatorConversions private static final SqlOperator OPERATOR = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase("hyper_unique")) .operatorBuilder(StringUtils.toUpperCase("hyper_unique"))
.operandTypes(SqlTypeFamily.ANY) .operandTypes(SqlTypeFamily.ANY)
.requiredOperands(0) .requiredOperandCount(0)
.returnTypeInference( .returnTypeInference(
opBinding -> RowSignatures.makeComplexType( opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(), opBinding.getTypeFactory(),

View File

@ -136,7 +136,7 @@ public class InformationSchemaTest extends BaseCalciteQueryTest
Assert.assertNotNull(rows); Assert.assertNotNull(rows);
Assert.assertEquals("There should be exactly 2 rows; any non-function syntax operator should get filtered out", Assert.assertEquals("There should be exactly 2 rows; any non-function syntax operator should get filtered out",
2, rows.size()); 2, rows.size());
Object[] expectedRow1 = {"druid", "INFORMATION_SCHEMA", "FOO", "FUNCTION", "NO", "'FOO(<ANY>)'"}; Object[] expectedRow1 = {"druid", "INFORMATION_SCHEMA", "FOO", "FUNCTION", "NO", "'FOO([<ANY>])'"};
Assert.assertTrue(rows.stream().anyMatch(row -> Arrays.equals(row, expectedRow1))); Assert.assertTrue(rows.stream().anyMatch(row -> Arrays.equals(row, expectedRow1)));
Object[] expectedRow2 = {"druid", "INFORMATION_SCHEMA", "BAR", "FUNCTION", "NO", "'BAR(<INTEGER>, <INTEGER>)'"}; Object[] expectedRow2 = {"druid", "INFORMATION_SCHEMA", "BAR", "FUNCTION", "NO", "'BAR(<INTEGER>, <INTEGER>)'"};
@ -166,7 +166,7 @@ public class InformationSchemaTest extends BaseCalciteQueryTest
final SqlOperator operator1 = OperatorConversions final SqlOperator operator1 = OperatorConversions
.operatorBuilder("FOO") .operatorBuilder("FOO")
.operandTypes(SqlTypeFamily.ANY) .operandTypes(SqlTypeFamily.ANY)
.requiredOperands(0) .requiredOperandCount(0)
.returnTypeInference( .returnTypeInference(
opBinding -> RowSignatures.makeComplexType( opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(), opBinding.getTypeFactory(),
@ -182,7 +182,7 @@ public class InformationSchemaTest extends BaseCalciteQueryTest
.operatorBuilder("BAR") .operatorBuilder("BAR")
.operandTypes(SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.NUMERIC)
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.requiredOperands(2) .requiredOperandCount(2)
.returnTypeInference( .returnTypeInference(
opBinding -> RowSignatures.makeComplexType( opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(), opBinding.getTypeFactory(),