SQL OperatorConversions: Introduce.aggregatorBuilder, allow CAST-as-literal. (#14249)

* SQL OperatorConversions: Introduce.aggregatorBuilder, allow CAST-as-literal.

Four main changes:

1) Provide aggregatorBuilder, a more consistent way of defining the
   SqlAggFunction we need for all of our SQL aggregators. The mechanism
   is analogous to the one we already use for SQL functions
   (OperatorConversions.operatorBuilder).

2) Allow CASTs of constants to be considered as "literalOperands". This
   fixes an issue where various of our operators are defined with
   OperandTypes.LITERAL as part of their checkers, which doesn't allow
   casts. However, in these cases we generally _do_ want to allow casts.
   The important piece is that the value must be reducible to a constant,
   not that the SQL text is literally a literal.

3) Update DataSketches SQL aggregators to use the new aggregatorBuilder
   functionality. The main user-visible effect here is [2]: the aggregators
   would now accept, for example, "CAST(0.99 AS DOUBLE)" as a literal
   argument. Other aggregators could be updated in a future patch.

4) Rename "requiredOperands" to "requiredOperandCount", because the
   old name was confusing. (It rhymes with "literalOperands" but the
   arguments mean different things.)

* Adjust method calls.
This commit is contained in:
Gian Merlino 2023-06-23 16:25:04 -07:00 committed by GitHub
parent 1d6c9657ec
commit 3d19b748fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 324 additions and 287 deletions

View File

@ -21,24 +21,30 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections;
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
public static final String NAME = "APPROX_COUNT_DISTINCT_DS_HLL";
private static final SqlAggFunction FUNCTION_INSTANCE = new HllSketchApproxCountDistinctSqlAggFunction();
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "lgK", "tgtHllType")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1, 2)
.returnTypeNonNull(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
public HllSketchApproxCountDistinctSqlAggregator()
{
@ -66,30 +72,4 @@ public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlA
) : null
);
}
private static class HllSketchApproxCountDistinctSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, lgK, tgtHllType)'";
HllSketchApproxCountDistinctSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.BIGINT),
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
)
),
SqlFunctionCategory.NUMERIC,
false,
false
);
}
}
}

View File

@ -47,7 +47,7 @@ public class HllSketchEstimateOperatorConversion implements SqlOperatorConversio
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME))
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.BOOLEAN)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeInference(ReturnTypes.DOUBLE)
.build();

View File

@ -46,7 +46,7 @@ public class HllSketchEstimateWithErrorBoundsOperatorConversion implements SqlOp
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME))
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.OTHER)
.build();

View File

@ -21,22 +21,29 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections;
public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new HllSketchSqlAggFunction();
private static final String NAME = "DS_HLL";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "lgK", "tgtHllType")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1, 2)
.returnTypeNonNull(SqlTypeName.OTHER)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
public HllSketchObjectSqlAggregator()
{
@ -61,30 +68,4 @@ public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator imp
null
);
}
private static class HllSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, lgK, tgtHllType)'";
HllSketchSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
);
}
}
}

View File

@ -28,8 +28,6 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringUtils;
@ -45,19 +43,27 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
{
public static final String CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH = "approxQuantileDsMaxStreamLength";
private static final SqlAggFunction FUNCTION_INSTANCE = new DoublesSketchApproxQuantileSqlAggFunction();
private static final String NAME = "APPROX_QUANTILE_DS";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "probability", "k")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
.returnTypeNonNull(SqlTypeName.DOUBLE)
.requiredOperandCount(2)
.literalOperands(1, 2)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
@Override
public SqlAggFunction calciteFunction()
@ -212,34 +218,4 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH
);
}
private static class DoublesSketchApproxQuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'";
private static final String SIGNATURE2 = "'" + NAME + "(column, probability, k)'";
DoublesSketchApproxQuantileSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
)
),
SqlFunctionCategory.NUMERIC,
false,
false
);
}
}
}

View File

@ -28,8 +28,6 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringUtils;
@ -43,17 +41,25 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new DoublesSketchSqlAggFunction();
private static final String NAME = "DS_QUANTILES_SKETCH";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "k")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
.returnTypeNonNull(SqlTypeName.OTHER)
.requiredOperandCount(1)
.literalOperands(1)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
@Override
public SqlAggFunction calciteFunction()
@ -139,30 +145,4 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
null
);
}
private static class DoublesSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE2 = "'" + NAME + "(column, k)'";
DoublesSketchSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
);
}
}
}

View File

@ -21,24 +21,30 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections;
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
public static final String NAME = "APPROX_COUNT_DISTINCT_DS_THETA";
private static final SqlAggFunction FUNCTION_INSTANCE = new ThetaSketchSqlAggFunction();
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "size")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1)
.returnTypeNonNull(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
public ThetaSketchApproxCountDistinctSqlAggregator()
{
@ -66,30 +72,4 @@ public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBase
) : null
);
}
private static class ThetaSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, size)'";
ThetaSketchSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.BIGINT),
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
SqlFunctionCategory.NUMERIC,
false,
false
);
}
}
}

View File

@ -21,20 +21,28 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import java.util.Collections;
public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new ThetaSketchObjectSqlAggFunction();
private static final String NAME = "DS_THETA";
private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "size")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1)
.returnTypeInference(ThetaSketchSqlOperators.RETURN_TYPE_INFERENCE)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
public ThetaSketchObjectSqlAggregator()
{
@ -59,30 +67,4 @@ public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator
null
);
}
private static class ThetaSketchObjectSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, size)'";
ThetaSketchObjectSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ThetaSketchSqlOperators.RETURN_TYPE_INFERENCE,
InferTypes.VARCHAR_1024,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
);
}
}
}

View File

@ -177,18 +177,19 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
+ " APPROX_COUNT_DISTINCT(SUBSTRING(dim2, 1, 1)),\n" // on extractionFn, using generic A.C.D.
+ " COUNT(DISTINCT SUBSTRING(dim2, 1, 1) || 'x'),\n" // on expression, using COUNT DISTINCT
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1, 21, 'HLL_8'),\n" // on native HllSketch column
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1)\n" // on native HllSketch column
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1),\n" // on native HllSketch column
+ " APPROX_COUNT_DISTINCT_DS_HLL(hllsketch_dim1, CAST(21 AS BIGINT))\n" // also native column
+ "FROM druid.foo";
final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) {
expectedResults = ImmutableList.of(
new Object[]{6L, 2L, 2L, 1L, 2L, 5L, 5L}
new Object[]{6L, 2L, 2L, 1L, 2L, 5L, 5L, 5L}
);
} else {
expectedResults = ImmutableList.of(
new Object[]{6L, 2L, 2L, 1L, 1L, 5L, 5L}
new Object[]{6L, 2L, 2L, 1L, 1L, 5L, 5L, 5L}
);
}
@ -252,7 +253,8 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
ROUND
),
new HllSketchMergeAggregatorFactory("a5", "hllsketch_dim1", 21, "HLL_8", null, ROUND),
new HllSketchMergeAggregatorFactory("a6", "hllsketch_dim1", null, null, null, ROUND)
new HllSketchMergeAggregatorFactory("a6", "hllsketch_dim1", null, null, null, ROUND),
new HllSketchMergeAggregatorFactory("a7", "hllsketch_dim1", 21, "HLL_4", null, ROUND)
)
)
.context(QUERY_CONTEXT_DEFAULT)

View File

@ -207,7 +207,7 @@ public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.01),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.5, 64),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.98, 256),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.99),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, CAST(0.99 AS DOUBLE)),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.999) FILTER(WHERE dim1 = 'abc')\n"

View File

@ -42,7 +42,7 @@ public class SleepOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("SLEEP")
.operandTypes(SqlTypeFamily.NUMERIC)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.VARCHAR) // always null
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -42,7 +42,7 @@ public class SleepOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("SLEEP")
.operandTypes(SqlTypeFamily.NUMERIC)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.VARCHAR) // always null
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -32,6 +32,7 @@ import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
@ -49,6 +50,7 @@ import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Static;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
@ -57,12 +59,14 @@ import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
@ -316,23 +320,33 @@ public class OperatorConversions
* Returns a builder that helps {@link SqlOperatorConversion} implementations create the {@link SqlFunction}
* objects they need to return from {@link SqlOperatorConversion#calciteOperator()}.
*/
public static OperatorBuilder operatorBuilder(final String name)
public static OperatorBuilder<SqlFunction> operatorBuilder(final String name)
{
return new OperatorBuilder(name);
return new OperatorBuilder<>(name);
}
public static class OperatorBuilder
/**
* Returns a builder that helps {@link SqlAggregator} implementations create the {@link SqlAggFunction} objects
* they need to return from {@link SqlAggregator#calciteFunction()}.
*/
public static OperatorBuilder<SqlAggFunction> aggregatorBuilder(final String name)
{
private final String name;
private SqlKind kind = SqlKind.OTHER_FUNCTION;
private SqlReturnTypeInference returnTypeInference;
private SqlFunctionCategory functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION;
return new AggregatorBuilder(name);
}
public static class OperatorBuilder<T extends SqlFunction>
{
protected final String name;
protected SqlKind kind = SqlKind.OTHER_FUNCTION;
protected SqlReturnTypeInference returnTypeInference;
protected SqlFunctionCategory functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION;
// For operand type checking
private SqlOperandTypeChecker operandTypeChecker;
private List<String> operandNames = Collections.emptyList();
private List<SqlTypeFamily> operandTypes;
private Integer requiredOperands = null;
private int[] literalOperands = null;
private Integer requiredOperandCount;
private int[] literalOperands;
private SqlOperandTypeInference operandTypeInference;
private OperatorBuilder(final String name)
@ -348,7 +362,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName)
public OperatorBuilder<T> returnTypeNonNull(final SqlTypeName typeName)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -365,7 +379,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNullable(final SqlTypeName typeName)
public OperatorBuilder<T> returnTypeNullable(final SqlTypeName typeName)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -382,7 +396,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName)
public OperatorBuilder<T> returnTypeCascadeNullable(final SqlTypeName typeName)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE);
@ -396,7 +410,7 @@ public class OperatorConversions
* {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be
* used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeArrayWithNullableElements(final SqlTypeName elementTypeName)
public OperatorBuilder<T> returnTypeArrayWithNullableElements(final SqlTypeName elementTypeName)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -413,7 +427,7 @@ public class OperatorConversions
* {@link #returnTypeArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be
* used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeNullableArrayWithNullableElements(final SqlTypeName elementTypeName)
public OperatorBuilder<T> returnTypeNullableArrayWithNullableElements(final SqlTypeName elementTypeName)
{
this.returnTypeInference = ReturnTypes.cascade(
opBinding -> Calcites.createSqlArrayTypeWithNullability(
@ -434,7 +448,7 @@ public class OperatorConversions
* {@link #returnTypeNullableArrayWithNullableElements}, or {@link #returnTypeInference(SqlReturnTypeInference)}
* must be used before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference)
public OperatorBuilder<T> returnTypeInference(final SqlReturnTypeInference returnTypeInference)
{
Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
@ -447,7 +461,7 @@ public class OperatorConversions
*
* The default, if not provided, is {@link SqlFunctionCategory#USER_DEFINED_FUNCTION}.
*/
public OperatorBuilder functionCategory(final SqlFunctionCategory functionCategory)
public OperatorBuilder<T> functionCategory(final SqlFunctionCategory functionCategory)
{
this.functionCategory = functionCategory;
return this;
@ -459,21 +473,32 @@ public class OperatorConversions
* One of {@link #operandTypes(SqlTypeFamily...)} or {@link #operandTypeChecker(SqlOperandTypeChecker)} must be used
* before calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
*/
public OperatorBuilder operandTypeChecker(final SqlOperandTypeChecker operandTypeChecker)
public OperatorBuilder<T> operandTypeChecker(final SqlOperandTypeChecker operandTypeChecker)
{
this.operandTypeChecker = operandTypeChecker;
return this;
}
/**
* Signifies that a function accepts operands with the provided names. This is used to implement
* {@link SqlOperandTypeChecker#getAllowedSignatures(SqlOperator, String)}. If not provided, the
* {@link #operandTypes} are used instead.
*/
public OperatorBuilder<T> operandNames(final String... operandNames)
{
this.operandNames = Arrays.asList(operandNames);
return this;
}
/**
* Signifies that a function accepts operands of type family given by {@param operandTypes}.
*
* May be used in conjunction with {@link #requiredOperands(int)} and {@link #literalOperands(int...)} in order
* May be used in conjunction with {@link #requiredOperandCount(int)} and {@link #literalOperands(int...)} in order
* to further refine operand checking logic.
*
* For deeper control, use {@link #operandTypeChecker(SqlOperandTypeChecker)} instead.
*/
public OperatorBuilder operandTypes(final SqlTypeFamily... operandTypes)
public OperatorBuilder<T> operandTypes(final SqlTypeFamily... operandTypes)
{
this.operandTypes = Arrays.asList(operandTypes);
return this;
@ -489,67 +514,97 @@ public class OperatorConversions
* Must be used in conjunction with {@link #operandTypes(SqlTypeFamily...)}; this method is not compatible with
* {@link #operandTypeChecker(SqlOperandTypeChecker)}.
*/
public OperatorBuilder requiredOperands(final int requiredOperands)
public OperatorBuilder<T> requiredOperandCount(final int requiredOperandCount)
{
this.requiredOperands = requiredOperands;
this.requiredOperandCount = requiredOperandCount;
return this;
}
/**
* Alias for {@link #requiredOperandCount(int)}. Deprecated because it means "operand count" rather than
* "specific operands", and therefore the name can cause confusion with {@link #literalOperands(int...)}. The latter
* really does mean "specific operands".
*/
@Deprecated
@SuppressWarnings("unused") // For compatibility with existing extensions
public OperatorBuilder<T> requiredOperands(final int requiredOperands)
{
return requiredOperandCount(requiredOperands);
}
/**
* Signifies that the operands at positions given by {@code literalOperands} must be literals.
*
* Must be used in conjunction with {@link #operandTypes(SqlTypeFamily...)}; this method is not compatible with
* {@link #operandTypeChecker(SqlOperandTypeChecker)}.
*/
public OperatorBuilder literalOperands(final int... literalOperands)
public OperatorBuilder<T> literalOperands(final int... literalOperands)
{
this.literalOperands = literalOperands;
return this;
}
public OperatorBuilder operandTypeInference(SqlOperandTypeInference operandTypeInference)
public OperatorBuilder<T> operandTypeInference(SqlOperandTypeInference operandTypeInference)
{
this.operandTypeInference = operandTypeInference;
return this;
}
public OperatorBuilder sqlKind(SqlKind kind)
{
this.kind = kind;
return this;
}
/**
* Creates a {@link SqlFunction} from this builder.
*/
public SqlFunction build()
@SuppressWarnings("unchecked")
public T build()
{
final IntSet nullableOperands = buildNullableOperands();
return (T) new SqlFunction(
name,
kind,
Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"),
buildOperandTypeInference(nullableOperands),
buildOperandTypeChecker(nullableOperands),
functionCategory
);
}
protected IntSet buildNullableOperands()
{
// Create "nullableOperands" set including all optional arguments.
final IntSet nullableOperands = new IntArraySet();
if (requiredOperands != null) {
IntStream.range(requiredOperands, operandTypes.size()).forEach(nullableOperands::add);
if (requiredOperandCount != null) {
IntStream.range(requiredOperandCount, operandTypes.size()).forEach(nullableOperands::add);
}
return nullableOperands;
}
final SqlOperandTypeChecker theOperandTypeChecker;
protected SqlOperandTypeChecker buildOperandTypeChecker(final IntSet nullableOperands)
{
if (operandTypeChecker == null) {
theOperandTypeChecker = new DefaultOperandTypeChecker(
return new DefaultOperandTypeChecker(
operandNames,
operandTypes,
requiredOperands == null ? operandTypes.size() : requiredOperands,
requiredOperandCount == null ? operandTypes.size() : requiredOperandCount,
nullableOperands,
literalOperands
);
} else if (operandTypes == null && requiredOperands == null && literalOperands == null) {
theOperandTypeChecker = operandTypeChecker;
} else if (operandNames.isEmpty()
&& operandTypes == null
&& requiredOperandCount == null
&& literalOperands == null) {
return operandTypeChecker;
} else {
throw new ISE(
"Cannot have both 'operandTypeChecker' and 'operandTypes' / 'requiredOperands' / 'literalOperands'"
"Cannot have both 'operandTypeChecker' and "
+ "'operandNames' / 'operandTypes' / 'requiredOperands' / 'literalOperands'"
);
}
}
protected SqlOperandTypeInference buildOperandTypeInference(final IntSet nullableOperands)
{
if (operandTypeInference == null) {
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
operandTypeInference = (callBinding, returnType, types) -> {
return (callBinding, returnType, types) -> {
for (int i = 0; i < types.length; i++) {
// calcite sql validate tries to do bad things to dynamic parameters if the type is inferred to be a string
if (callBinding.operand(i).isA(ImmutableSet.of(SqlKind.DYNAMIC_PARAM))) {
@ -562,15 +617,49 @@ public class OperatorConversions
}
}
};
} else {
return operandTypeInference;
}
return new SqlFunction(
name,
kind,
Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"),
operandTypeInference,
theOperandTypeChecker,
functionCategory
);
}
}
public static class AggregatorBuilder extends OperatorBuilder<SqlAggFunction>
{
public AggregatorBuilder(String name)
{
super(name);
}
/**
* Create a {@link SqlAggFunction} from this builder.
*/
@Override
public SqlAggFunction build()
{
final IntSet nullableOperands = buildNullableOperands();
final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands);
final SqlOperandTypeChecker operandTypeChecker = buildOperandTypeChecker(nullableOperands);
class DruidSqlAggFunction extends SqlAggFunction
{
public DruidSqlAggFunction()
{
super(
name,
null,
AggregatorBuilder.this.kind,
returnTypeInference,
operandTypeInference,
operandTypeChecker,
functionCategory,
false,
false,
Optionality.FORBIDDEN
);
}
}
return new DruidSqlAggFunction();
}
}
@ -655,6 +744,11 @@ public class OperatorConversions
@VisibleForTesting
static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
{
/**
* Operand names for {@link #getAllowedSignatures(SqlOperator, String)}. May be empty, in which case the
* {@link #operandTypes} are used instead.
*/
private final List<String> operandNames;
private final List<SqlTypeFamily> operandTypes;
private final int requiredOperands;
private final IntSet nullableOperands;
@ -662,6 +756,7 @@ public class OperatorConversions
@VisibleForTesting
DefaultOperandTypeChecker(
final List<String> operandNames,
final List<SqlTypeFamily> operandTypes,
final int requiredOperands,
final IntSet nullableOperands,
@ -669,10 +764,15 @@ public class OperatorConversions
)
{
Preconditions.checkArgument(requiredOperands <= operandTypes.size() && requiredOperands >= 0);
this.operandNames = Preconditions.checkNotNull(operandNames, "operandNames");
this.operandTypes = Preconditions.checkNotNull(operandTypes, "operandTypes");
this.requiredOperands = requiredOperands;
this.nullableOperands = Preconditions.checkNotNull(nullableOperands, "nullableOperands");
if (!operandNames.isEmpty() && operandNames.size() != operandTypes.size()) {
throw new ISE("Operand name count[%s] and type count[%s] must match", operandNames.size(), operandTypes.size());
}
if (literalOperands == null) {
this.literalOperands = IntSets.EMPTY_SET;
} else {
@ -688,8 +788,8 @@ public class OperatorConversions
final SqlNode operand = callBinding.operands().get(i);
if (literalOperands.contains(i)) {
// Verify that 'operand' is a literal.
if (!SqlUtil.isLiteral(operand)) {
// Verify that 'operand' is a literal. Allow CAST, since we can reduce these away later.
if (!SqlUtil.isLiteral(operand, true)) {
return throwOrReturn(
throwOnFailure,
callBinding,
@ -739,7 +839,25 @@ public class OperatorConversions
@Override
public String getAllowedSignatures(SqlOperator op, String opName)
{
return SqlUtil.getAliasedSignature(op, opName, operandTypes);
final List<?> operands = !operandNames.isEmpty() ? operandNames : operandTypes;
final StringBuilder ret = new StringBuilder();
ret.append("'");
ret.append(opName);
ret.append("(");
for (int i = 0; i < operands.size(); i++) {
if (i > 0) {
ret.append(", ");
}
if (i >= requiredOperands) {
ret.append("[");
}
ret.append("<").append(operands.get(i)).append(">");
}
for (int i = requiredOperands; i < operands.size(); i++) {
ret.append("]");
}
ret.append(")'");
return ret.toString();
}
@Override
@ -772,7 +890,7 @@ public class OperatorConversions
{
return new DirectOperatorConversion(
operatorBuilder(sqlOperator)
.requiredOperands(1)
.requiredOperandCount(1)
.operandTypes(SqlTypeFamily.NUMERIC)
.returnTypeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC)
@ -785,7 +903,7 @@ public class OperatorConversions
{
return new DirectOperatorConversion(
operatorBuilder(sqlOperator)
.requiredOperands(2)
.requiredOperandCount(2)
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
.returnTypeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC)
@ -798,7 +916,7 @@ public class OperatorConversions
{
return new DirectOperatorConversion(
operatorBuilder(StringUtils.toUpperCase(sqlOperator))
.requiredOperands(1)
.requiredOperandCount(1)
.operandTypes(SqlTypeFamily.NUMERIC)
.returnTypeNullable(SqlTypeName.DOUBLE)
.functionCategory(SqlFunctionCategory.NUMERIC)

View File

@ -40,7 +40,7 @@ public class BTrimOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.requiredOperandCount(1)
.build();
@Override

View File

@ -80,7 +80,7 @@ public class ContainsOperatorConversion extends DirectOperatorConversion
return OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(functionName))
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.requiredOperandCount(2)
.literalOperands(1)
.returnTypeNonNull(SqlTypeName.BOOLEAN)
.functionCategory(SqlFunctionCategory.STRING)

View File

@ -67,7 +67,7 @@ public class DateTruncOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("DATE_TRUNC")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -40,7 +40,7 @@ public class LPadOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(2)
.requiredOperandCount(2)
.build();
@Override

View File

@ -40,7 +40,7 @@ public class LTrimOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.requiredOperandCount(1)
.build();
@Override

View File

@ -40,7 +40,7 @@ public class ParseLongOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.requiredOperandCount(1)
.build();
@Override

View File

@ -40,7 +40,7 @@ public class RPadOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(2)
.requiredOperandCount(2)
.build();
@Override

View File

@ -40,7 +40,7 @@ public class RTrimOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.requiredOperands(1)
.requiredOperandCount(1)
.build();
@Override

View File

@ -38,7 +38,7 @@ public class RegexpExtractOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("REGEXP_EXTRACT")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
.requiredOperands(2)
.requiredOperandCount(2)
.literalOperands(1, 2)
.returnTypeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)

View File

@ -44,7 +44,7 @@ public class RegexpLikeOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("REGEXP_LIKE")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.requiredOperandCount(2)
.literalOperands(1)
.returnTypeNonNull(SqlTypeName.BOOLEAN)
.functionCategory(SqlFunctionCategory.STRING)

View File

@ -35,7 +35,7 @@ public class RoundOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("ROUND")
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeInference(ReturnTypes.ARG0)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();

View File

@ -43,7 +43,7 @@ public class SubstringOperatorConversion implements SqlOperatorConversion
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.functionCategory(SqlFunctionCategory.STRING)
.returnTypeInference(ReturnTypes.ARG0)
.requiredOperands(2)
.requiredOperandCount(2)
.build();
@Override

View File

@ -31,7 +31,7 @@ public class TextcatOperatorConversion extends DirectOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("textcat")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.STRING)
.build();

View File

@ -41,7 +41,7 @@ public class TimeCeilOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_CEIL")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -43,7 +43,7 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_EXTRACT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -56,7 +56,7 @@ public class TimeFloorOperatorConversion implements SqlOperatorConversion
public static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(SQL_FUNCTION_NAME)
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -46,7 +46,7 @@ public class TimeFormatOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_FORMAT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeCascadeNullable(SqlTypeName.VARCHAR)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -43,7 +43,7 @@ public class TimeParseOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_PARSE")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -43,7 +43,7 @@ public class TimeShiftOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TIME_SHIFT")
.operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.requiredOperands(3)
.requiredOperandCount(3)
.returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
.functionCategory(SqlFunctionCategory.TIMEDATE)
.build();

View File

@ -37,7 +37,7 @@ public class TruncateOperatorConversion implements SqlOperatorConversion
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder("TRUNCATE")
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeInference(ReturnTypes.ARG0)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();

View File

@ -51,6 +51,7 @@ import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@RunWith(Enclosed.class)
@ -65,6 +66,7 @@ public class OperatorConversionsTest
public void testGetOperandCountRange()
{
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
Collections.emptyList(),
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2,
IntSets.EMPTY_SET,
@ -79,6 +81,7 @@ public class OperatorConversionsTest
public void testIsOptional()
{
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
Collections.emptyList(),
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2,
IntSets.EMPTY_SET,
@ -95,7 +98,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testAllowFullOperands")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -119,7 +122,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testRequiredOperandsOnly")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -140,7 +143,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testLiteralOperandCheckLiteral")
.operandTypes(SqlTypeFamily.INTEGER)
.requiredOperands(1)
.requiredOperandCount(1)
.literalOperands(0)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
@ -162,7 +165,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testLiteralOperandCheckLiteralThrow")
.operandTypes(SqlTypeFamily.INTEGER)
.requiredOperands(1)
.requiredOperandCount(1)
.literalOperands(0)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
@ -184,7 +187,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testAnyTypeOperand")
.operandTypes(SqlTypeFamily.ANY)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -205,7 +208,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testCastableFromDatetimeFamilyToTimestamp")
.operandTypes(SqlTypeFamily.DATETIME)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -235,7 +238,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -259,7 +262,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullLiteralForNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -283,7 +286,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -304,7 +307,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableCascade")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeCascadeNullable(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -325,7 +328,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNullable(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -346,7 +349,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNonNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -372,7 +375,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNullLiteralForNonNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -398,7 +401,7 @@ public class OperatorConversionsTest
SqlFunction function = OperatorConversions
.operatorBuilder("testNonCastableType")
.operandTypes(SqlTypeFamily.CURSOR, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
@ -418,6 +421,41 @@ public class OperatorConversionsTest
);
}
@Test
public void testSignatureWithNames()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testSignatureWithNames")
.operandNames("x", "y", "z")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE, SqlTypeFamily.ANY)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertEquals(
"'testSignatureWithNames(<x>, [<y>, [<z>]])'",
typeChecker.getAllowedSignatures(function, function.getName())
);
}
@Test
public void testSignatureWithoutNames()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testSignatureWithoutNames")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE, SqlTypeFamily.ANY)
.requiredOperandCount(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertEquals(
"'testSignatureWithoutNames(<INTEGER>, [<DATE>, [<ANY>]])'",
typeChecker.getAllowedSignatures(function, function.getName())
);
}
private static SqlCallBinding mockCallBinding(
SqlFunction function,
List<OperandSpec> actualOperands

View File

@ -68,7 +68,7 @@ public class DruidOperatorTableTest
final SqlOperator operator1 = OperatorConversions
.operatorBuilder("FOO")
.operandTypes(SqlTypeFamily.ANY)
.requiredOperands(0)
.requiredOperandCount(0)
.returnTypeInference(
opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(),

View File

@ -69,7 +69,7 @@ public class DruidRexExecutorTest extends InitializedNullHandlingTest
private static final SqlOperator OPERATOR = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase("hyper_unique"))
.operandTypes(SqlTypeFamily.ANY)
.requiredOperands(0)
.requiredOperandCount(0)
.returnTypeInference(
opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(),

View File

@ -136,7 +136,7 @@ public class InformationSchemaTest extends BaseCalciteQueryTest
Assert.assertNotNull(rows);
Assert.assertEquals("There should be exactly 2 rows; any non-function syntax operator should get filtered out",
2, rows.size());
Object[] expectedRow1 = {"druid", "INFORMATION_SCHEMA", "FOO", "FUNCTION", "NO", "'FOO(<ANY>)'"};
Object[] expectedRow1 = {"druid", "INFORMATION_SCHEMA", "FOO", "FUNCTION", "NO", "'FOO([<ANY>])'"};
Assert.assertTrue(rows.stream().anyMatch(row -> Arrays.equals(row, expectedRow1)));
Object[] expectedRow2 = {"druid", "INFORMATION_SCHEMA", "BAR", "FUNCTION", "NO", "'BAR(<INTEGER>, <INTEGER>)'"};
@ -166,7 +166,7 @@ public class InformationSchemaTest extends BaseCalciteQueryTest
final SqlOperator operator1 = OperatorConversions
.operatorBuilder("FOO")
.operandTypes(SqlTypeFamily.ANY)
.requiredOperands(0)
.requiredOperandCount(0)
.returnTypeInference(
opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(),
@ -182,7 +182,7 @@ public class InformationSchemaTest extends BaseCalciteQueryTest
.operatorBuilder("BAR")
.operandTypes(SqlTypeFamily.NUMERIC)
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.requiredOperands(2)
.requiredOperandCount(2)
.returnTypeInference(
opBinding -> RowSignatures.makeComplexType(
opBinding.getTypeFactory(),