diff --git a/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java b/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java index 4a61f0271ee..a6c23551598 100644 --- a/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java +++ b/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java @@ -21,8 +21,6 @@ package org.apache.druid.compressedbigdecimal; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -36,12 +34,12 @@ import org.apache.calcite.util.Optionality; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -71,12 +69,10 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -88,13 +84,8 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg // fetch sum column expression DruidExpression sumColumn = Expressions.toDruidExpression( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (sumColumn == null) { @@ -114,12 +105,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg Integer size = null; if (aggregateCall.getArgList().size() >= 2) { - RexNode sizeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + RexNode sizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); size = ((Number) RexLiteral.value(sizeArg)).intValue(); } @@ -128,12 +114,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg Integer scale = null; if (aggregateCall.getArgList().size() >= 3) { - RexNode scaleArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + RexNode scaleArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); scale = ((Number) RexLiteral.value(scaleArg)).intValue(); } @@ -141,12 +122,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg Boolean useStrictNumberParsing = null; if (aggregateCall.getArgList().size() >= 4) { - RexNode useStrictNumberParsingArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(3) - ); + RexNode useStrictNumberParsingArg = inputAccessor.getField(aggregateCall.getArgList().get(3)); useStrictNumberParsing = RexLiteral.booleanValue(useStrictNumberParsingArg); } diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java index ca0a4acc603..ebb6c7f4b14 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.query.aggregation.tdigestsketch.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -36,13 +34,12 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -63,25 +60,18 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { - final RexNode inputOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0)); final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), inputOperand ); if (input == null) { @@ -93,12 +83,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION; if (aggregateCall.getArgList().size() > 1) { - RexNode compressionOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + RexNode compressionOperand = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!compressionOperand.isA(SqlKind.LITERAL)) { // compressionOperand must be a literal in order to plan. return null; diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java index 379e889d383..ee63444f6d7 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.tdigestsketch.sql; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -39,13 +37,12 @@ import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorF import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchToQuantilePostAggregator; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -66,12 +63,10 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) @@ -79,13 +74,8 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator // This is expected to be a tdigest sketch final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -95,12 +85,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator final String sketchName = StringUtils.format("%s:agg", name); // this is expected to be quantile fraction - final RexNode quantileArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode quantileArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!quantileArg.isA(SqlKind.LITERAL)) { // Quantile must be a literal in order to plan. @@ -110,12 +95,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator final double quantile = ((Number) RexLiteral.value(quantileArg)).floatValue(); Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION; if (aggregateCall.getArgList().size() > 2) { - final RexNode compressionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode compressionArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); compression = ((Number) RexLiteral.value(compressionArg)).intValue(); } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java index c6dd3e7afa0..d221b72ac1c 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; @@ -36,7 +34,6 @@ import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggrega import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -44,6 +41,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -66,38 +64,26 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // for string columns. - final RexNode columnRexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0)); - final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode); + final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode); if (columnArg == null) { return null; } final int logK; if (aggregateCall.getArgList().size() >= 2) { - final RexNode logKarg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode logKarg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!logKarg.isA(SqlKind.LITERAL)) { // logK must be a literal in order to plan. @@ -111,12 +97,7 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator final String tgtHllType; if (aggregateCall.getArgList().size() >= 3) { - final RexNode tgtHllTypeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode tgtHllTypeArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!tgtHllTypeArg.isA(SqlKind.LITERAL)) { // tgtHllType must be a literal in order to plan. @@ -132,9 +113,10 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; if (columnArg.isDirectColumnAccess() - && rowSignature.getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) - .orElse(false)) { + && inputAccessor.getInputRowSignature() + .getColumnType(columnArg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new HllSketchMergeAggregatorFactory( aggregatorName, columnArg.getDirectColumn(), diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java index 6c1b5720af4..08c7a1b123f 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.quantiles.sql; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -37,14 +35,13 @@ import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAg import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchToQuantilePostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -75,25 +72,18 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -101,12 +91,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator final AggregatorFactory aggregatorFactory; final String histogramName = StringUtils.format("%s:agg", name); - final RexNode probabilityArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!probabilityArg.isA(SqlKind.LITERAL)) { // Probability must be a literal in order to plan. @@ -117,12 +102,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator final int k; if (aggregateCall.getArgList().size() >= 3) { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!resolutionArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java index 049e1284a91..8331ab72064 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.quantiles.sql; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -35,14 +33,13 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.datasketches.SketchQueryContext; import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAggregatorFactory; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -71,25 +68,18 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -100,12 +90,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator final int k; if (aggregateCall.getArgList().size() >= 2) { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!resolutionArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java index 6564b276c97..bf35cd665ae 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; @@ -34,7 +32,6 @@ import org.apache.druid.query.aggregation.datasketches.theta.SketchMergeAggregat import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -42,6 +39,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -60,38 +58,26 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // for string columns. - final RexNode columnRexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0)); - final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode); + final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode); if (columnArg == null) { return null; } final int sketchSize; if (aggregateCall.getArgList().size() >= 2) { - final RexNode sketchSizeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode sketchSizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!sketchSizeArg.isA(SqlKind.LITERAL)) { // logK must be a literal in order to plan. @@ -107,9 +93,10 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; if (columnArg.isDirectColumnAccess() - && rowSignature.getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) - .orElse(false)) { + && inputAccessor.getInputRowSignature() + .getColumnType(columnArg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new SketchMergeAggregatorFactory( aggregatorName, columnArg.getDirectColumn(), diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java index 9d6ddac89a8..a9b1aaa627d 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.tuple.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,7 +36,6 @@ import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketc import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -46,6 +43,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -69,12 +67,10 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -86,12 +82,7 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator final int nominalEntries; final int metricExpressionEndIndex; final int lastArgIndex = argList.size() - 1; - final RexNode potentialNominalEntriesArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - argList.get(lastArgIndex) - ); + final RexNode potentialNominalEntriesArg = inputAccessor.getField(argList.get(lastArgIndex)); if (potentialNominalEntriesArg.isA(SqlKind.LITERAL) && RexLiteral.value(potentialNominalEntriesArg) instanceof Number) { @@ -107,16 +98,11 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator for (int i = 0; i <= metricExpressionEndIndex; i++) { final String fieldName; - final RexNode columnRexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - argList.get(i) - ); + final RexNode columnRexNode = inputAccessor.getField(argList.get(i)); final DruidExpression columnArg = Expressions.toDruidExpression( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), columnRexNode ); if (columnArg == null) { @@ -124,9 +110,10 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator } if (columnArg.isDirectColumnAccess() && - rowSignature.getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) - .orElse(false)) { + inputAccessor.getInputRowSignature() + .getColumnType(columnArg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { fieldName = columnArg.getDirectColumn(); } else { final RelDataType dataType = columnRexNode.getType(); diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java index 0ec265595e1..6a1ca49067e 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.query.aggregation.bloom.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,13 +36,13 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -65,25 +63,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final RexNode inputOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0)); final DruidExpression input = Expressions.toDruidExpression( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), inputOperand ); if (input == null) { @@ -92,12 +83,7 @@ public class BloomFilterSqlAggregator implements SqlAggregator final AggregatorFactory aggregatorFactory; final String aggName = StringUtils.format("%s:agg", name); - final RexNode maxNumEntriesOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode maxNumEntriesOperand = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!maxNumEntriesOperand.isA(SqlKind.LITERAL)) { // maxNumEntriesOperand must be a literal in order to plan. diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java index 3f0bd14f844..fdc61796c4d 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.histogram.sql; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,13 +36,12 @@ import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram; import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogramAggregatorFactory; import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -65,25 +62,18 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -91,12 +81,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator final AggregatorFactory aggregatorFactory; final String histogramName = StringUtils.format("%s:agg", name); - final RexNode probabilityArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!probabilityArg.isA(SqlKind.LITERAL)) { // Probability must be a literal in order to plan. @@ -107,12 +92,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator final int numBuckets; if (aggregateCall.getArgList().size() >= 3) { - final RexNode numBucketsArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode numBucketsArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!numBucketsArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -126,12 +106,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator final double lowerLimit; if (aggregateCall.getArgList().size() >= 4) { - final RexNode lowerLimitArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(3) - ); + final RexNode lowerLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(3)); if (!lowerLimitArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -145,12 +120,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator final double upperLimit; if (aggregateCall.getArgList().size() >= 5) { - final RexNode upperLimitArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(4) - ); + final RexNode upperLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(4)); if (!upperLimitArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -164,12 +134,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator final FixedBucketsHistogram.OutlierHandlingMode outlierHandlingMode; if (aggregateCall.getArgList().size() >= 6) { - final RexNode outlierHandlingModeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(5) - ); + final RexNode outlierHandlingModeArg = inputAccessor.getField(aggregateCall.getArgList().get(5)); if (!outlierHandlingModeArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java index 9ba7604d709..a3fe8dc5458 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.histogram.sql; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -39,14 +37,13 @@ import org.apache.druid.query.aggregation.histogram.ApproximateHistogramAggregat import org.apache.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory; import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -67,25 +64,18 @@ public class QuantileSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -93,12 +83,7 @@ public class QuantileSqlAggregator implements SqlAggregator final AggregatorFactory aggregatorFactory; final String histogramName = StringUtils.format("%s:agg", name); - final RexNode probabilityArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!probabilityArg.isA(SqlKind.LITERAL)) { // Probability must be a literal in order to plan. @@ -109,12 +94,7 @@ public class QuantileSqlAggregator implements SqlAggregator final int resolution; if (aggregateCall.getArgList().size() >= 3) { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!resolutionArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -170,7 +150,10 @@ public class QuantileSqlAggregator implements SqlAggregator // No existing match found. Create a new one. if (input.isDirectColumnAccess()) { - if (rowSignature.getColumnType(input.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) { + if (inputAccessor.getInputRowSignature() + .getColumnType(input.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory( histogramName, input.getDirectColumn(), diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index ee8c469c3b8..b2ed565d627 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -21,9 +21,7 @@ package org.apache.druid.query.aggregation.variance.sql; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; @@ -40,15 +38,14 @@ import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.table.RowSignatures; @@ -77,25 +74,19 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final RexNode inputOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0)); + final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), inputOperand ); if (input == null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java index 3a3e43dd7b8..5c06332a9bc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java @@ -20,14 +20,13 @@ package org.apache.druid.sql.calcite.aggregation; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import javax.annotation.Nullable; import java.util.List; @@ -48,28 +47,24 @@ public class Aggregations * * 1) They can take direct field accesses or expressions as inputs. * 2) They cannot implicitly cast strings to numbers when using a direct field access. - * * @param plannerContext SQL planner context - * @param rowSignature input row signature * @param call aggregate call object - * @param project project that should be applied before aggregation; may be null + * @param inputAccessor gives access to input fields and schema * * @return list of expressions corresponding to aggregator arguments, or null if any cannot be translated */ @Nullable public static List getArgumentsForSimpleAggregator( - final RexBuilder rexBuilder, final PlannerContext plannerContext, - final RowSignature rowSignature, final AggregateCall call, - @Nullable final Project project + final InputAccessor inputAccessor ) { final List args = call .getArgList() .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, rowSignature, rexNode)) + .map(i -> inputAccessor.getField(i)) + .map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, inputAccessor.getInputRowSignature(), rexNode)) .collect(Collectors.toList()); if (args.stream().noneMatch(Objects::isNull)) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java index 0ff7972657e..eceb4ebbf80 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.sql.calcite.aggregation; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; @@ -30,8 +28,8 @@ import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.Optionality; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -66,24 +64,20 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { return delegate.toDruidAggregation( plannerContext, - rowSignature, virtualColumnRegistry, - rexBuilder, name, aggregateCall, - project, + inputAccessor, existingAggregations, finalizeAggregations ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java index d21f6ebb75a..ec494a2fec4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java @@ -25,6 +25,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -42,6 +43,44 @@ public interface SqlAggregator */ SqlAggFunction calciteFunction(); + /** + * Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters; + * they will be applied to your aggregator in a later step. + * + * @param plannerContext SQL planner context + * @param virtualColumnRegistry re-usable virtual column references + * @param name desired output name of the aggregation + * @param aggregateCall aggregate call object + * @param inputAccessor gives access to input fields and schema + * @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely + * ignored if you do not want to re-use existing aggregations. + * @param finalizeAggregations true if this query should include explicit finalization for all of its + * aggregators, where required. This is set for subqueries where Druid's native query + * layer does not do this automatically. + * @return aggregation, or null if the call cannot be translated + */ + @Nullable + default Aggregation toDruidAggregation( + PlannerContext plannerContext, + VirtualColumnRegistry virtualColumnRegistry, + String name, + AggregateCall aggregateCall, + InputAccessor inputAccessor, + List existingAggregations, + boolean finalizeAggregations + ) + { + return toDruidAggregation(plannerContext, + inputAccessor.getInputRowSignature(), + virtualColumnRegistry, + inputAccessor.getRexBuilder(), + name, + aggregateCall, + inputAccessor.getProject(), + existingAggregations, + finalizeAggregations); + } + /** * Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters; * they will be applied to your aggregator in a later step. @@ -62,7 +101,7 @@ public interface SqlAggregator * @return aggregation, or null if the call cannot be translated */ @Nullable - Aggregation toDruidAggregation( + default Aggregation toDruidAggregation( PlannerContext plannerContext, RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, @@ -72,5 +111,8 @@ public interface SqlAggregator Project project, List existingAggregations, boolean finalizeAggregations - ); + ) + { + throw new RuntimeException("unimplemented fallback method!"); + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java index ed6652181eb..be21701d1eb 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -39,18 +37,17 @@ import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.List; -import java.util.stream.Collectors; public class ArrayConcatSqlAggregator implements SqlAggregator { @@ -67,21 +64,15 @@ public class ArrayConcatSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final List arguments = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List arguments = inputAccessor.getFields(aggregateCall.getArgList()); Integer maxSizeBytes = null; if (arguments.size() > 1) { @@ -92,7 +83,7 @@ public class ArrayConcatSqlAggregator implements SqlAggregator } maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); } - final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0)); + final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0)); final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable(); final String fieldName; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index 5136ed3c947..9af5210905e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -21,9 +21,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -41,18 +39,17 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.List; -import java.util.stream.Collectors; public class ArraySqlAggregator implements SqlAggregator { @@ -69,21 +66,16 @@ public class ArraySqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final List arguments = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List arguments = + inputAccessor.getFields(aggregateCall.getArgList()); Integer maxSizeBytes = null; if (arguments.size() > 1) { @@ -94,7 +86,7 @@ public class ArraySqlAggregator implements SqlAggregator } maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); } - final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0)); + final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0)); if (arg == null) { // can't translate argument return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java index a938bdca0b8..3814f8d9ad8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java @@ -22,8 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -33,14 +31,13 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -58,23 +55,19 @@ public class AvgSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final List arguments = Aggregations.getArgumentsForSimpleAggregator( - rexBuilder, plannerContext, - rowSignature, aggregateCall, - project + inputAccessor ); if (arguments == null) { @@ -85,11 +78,11 @@ public class AvgSqlAggregator implements SqlAggregator final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory( countName, plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), virtualColumnRegistry, - rexBuilder, + inputAccessor.getRexBuilder(), aggregateCall, - project + inputAccessor ); final DruidExpression arg = Iterables.getOnlyElement(arguments); @@ -108,12 +101,8 @@ public class AvgSqlAggregator implements SqlAggregator if (arg.isDirectColumnAccess()) { fieldName = arg.getDirectColumn(); } else { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); + final RexNode resolutionArg = inputAccessor.getField( + Iterables.getOnlyElement(aggregateCall.getArgList())); fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, resolutionArg.getType()); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java index d8758141dfb..a5c7fb61cff 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java @@ -21,8 +21,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; @@ -41,12 +39,12 @@ import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -122,12 +120,10 @@ public class BitwiseSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -135,8 +131,8 @@ public class BitwiseSqlAggregator implements SqlAggregator final List arguments = aggregateCall .getArgList() .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .map(i -> inputAccessor.getField(i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode)) .collect(Collectors.toList()); if (arguments.stream().anyMatch(Objects::isNull)) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java index e4dedd95ce2..699c7a8d1c6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java @@ -22,9 +22,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; @@ -42,7 +40,6 @@ import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFact import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -50,6 +47,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -72,26 +70,20 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // for string columns. - final RexNode rexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); + final RexNode rexNode = inputAccessor.getField( + Iterables.getOnlyElement(aggregateCall.getArgList())); - final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, rexNode); + final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode); if (arg == null) { return null; } @@ -100,7 +92,10 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; if (arg.isDirectColumnAccess() - && rowSignature.getColumnType(arg.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) { + && inputAccessor.getInputRowSignature() + .getColumnType(arg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true); } else { final RelDataType dataType = rexNode.getType(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java index edc7e3ce50a..c28ac8eebb2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java @@ -22,7 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -40,6 +39,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -69,15 +69,10 @@ public class CountSqlAggregator implements SqlAggregator final VirtualColumnRegistry virtualColumnRegistry, final RexBuilder rexBuilder, final AggregateCall aggregateCall, - final Project project + final InputAccessor inputAccessor ) { - final RexNode rexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); + final RexNode rexNode = inputAccessor.getField(Iterables.getOnlyElement(aggregateCall.getArgList())); if (rexNode.getType().isNullable()) { final DimFilter nonNullFilter = Expressions.toFilter( @@ -102,28 +97,25 @@ public class CountSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final List args = Aggregations.getArgumentsForSimpleAggregator( - rexBuilder, plannerContext, - rowSignature, aggregateCall, - project + inputAccessor ); if (args == null) { return null; } + // FIXME: is-all-literal if (args.isEmpty()) { // COUNT(*) return Aggregation.create(new CountAggregatorFactory(name)); @@ -132,12 +124,10 @@ public class CountSqlAggregator implements SqlAggregator if (plannerContext.getPlannerConfig().isUseApproximateCountDistinct()) { return approxCountDistinctAggregator.toDruidAggregation( plannerContext, - rowSignature, virtualColumnRegistry, - rexBuilder, name, aggregateCall, - project, + inputAccessor, existingAggregations, finalizeAggregations ); @@ -150,11 +140,11 @@ public class CountSqlAggregator implements SqlAggregator AggregatorFactory theCount = createCountAggregatorFactory( name, plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), virtualColumnRegistry, - rexBuilder, + inputAccessor.getRexBuilder(), aggregateCall, - project + inputAccessor ); return Aggregation.create(theCount); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java index 5f1b3c3228d..abaeede9948 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -53,19 +51,18 @@ import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; public class EarliestLatestAnySqlAggregator implements SqlAggregator { @@ -180,23 +177,17 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { - final List rexNodes = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List rexNodes = inputAccessor.getFields(aggregateCall.getArgList()); - final List args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes); + final List args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes); if (args == null) { return null; @@ -216,7 +207,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator final String fieldName = getColumnName(plannerContext, virtualColumnRegistry, args.get(0), rexNodes.get(0)); - if (!rowSignature.contains(ColumnHolder.TIME_COLUMN_NAME) && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) { + if (!inputAccessor.getInputRowSignature().contains(ColumnHolder.TIME_COLUMN_NAME) + && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) { // This code is being run as part of the exploratory volcano planner, currently, the definition of these // aggregators does not tell Calcite that they depend on a __time column being in existence, instead we are // allowing the volcano planner to explore paths that put projections which eliminate the time column in between diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java index 95b70e1f1e5..c12be459cf5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,19 +36,18 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; public class EarliestLatestBySqlAggregator implements SqlAggregator { @@ -76,23 +73,17 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { - final List rexNodes = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List rexNodes = inputAccessor.getFields(aggregateCall.getArgList()); - final List args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes); + final List args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes); if (args == null) { return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java index 156c3995c6f..ec829df11d7 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java @@ -22,7 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -34,6 +33,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -53,24 +53,22 @@ public class GroupingSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, - List existingAggregations, - boolean finalizeAggregations + final InputAccessor inputAccessor, + final List existingAggregations, + final boolean finalizeAggregations ) { List arguments = aggregateCall.getArgList() .stream() .map(i -> getColumnName( plannerContext, - rowSignature, - project, + inputAccessor.getInputRowSignature(), + inputAccessor.getProject(), virtualColumnRegistry, - rexBuilder.getTypeFactory(), + inputAccessor.getRexBuilder().getTypeFactory(), i )) .filter(Objects::nonNull) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java index 0eb2c1085c0..6e7de762b23 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java @@ -21,18 +21,16 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlInternalOperators; import org.apache.druid.query.aggregation.post.ExpressionPostAggregator; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -59,12 +57,10 @@ public class LiteralSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) @@ -73,7 +69,7 @@ public class LiteralSqlAggregator implements SqlAggregator return null; } final RexNode literal = aggregateCall.rexList.get(0); - final DruidExpression expr = Expressions.toDruidExpression(plannerContext, rowSignature, literal); + final DruidExpression expr = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), literal); if (expr == null) { return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java index 01782668663..5da064c285d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java @@ -21,18 +21,16 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.druid.error.DruidException; import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -57,12 +55,10 @@ public abstract class SimpleSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) @@ -72,11 +68,9 @@ public abstract class SimpleSqlAggregator implements SqlAggregator } final List arguments = Aggregations.getArgumentsForSimpleAggregator( - rexBuilder, plannerContext, - rowSignature, aggregateCall, - project + inputAccessor ); if (arguments == null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java index b391100ff3a..7c1389de3fe 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java @@ -21,9 +21,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -47,13 +45,13 @@ import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.table.RowSignatures; @@ -89,12 +87,10 @@ public class StringSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -102,20 +98,15 @@ public class StringSqlAggregator implements SqlAggregator final List arguments = aggregateCall .getArgList() .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .map(i -> inputAccessor.getField(i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode)) .collect(Collectors.toList()); if (arguments.stream().anyMatch(Objects::isNull)) { return null; } - RexNode separatorNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + RexNode separatorNode = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!separatorNode.isA(SqlKind.LITERAL)) { // separator must be a literal return null; @@ -133,12 +124,7 @@ public class StringSqlAggregator implements SqlAggregator Integer maxSizeBytes = null; if (arguments.size() > 2) { - RexNode maxBytes = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + RexNode maxBytes = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!maxBytes.isA(SqlKind.LITERAL)) { // maxBytes must be a literal return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java index 7dd158d91f3..00cd391eab2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java @@ -20,14 +20,12 @@ package org.apache.druid.sql.calcite.expression; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.druid.java.util.common.UOE; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -55,12 +53,10 @@ public class WindowSqlAggregate implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index 9c41d79070b..1cf79b6dc12 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -580,7 +580,11 @@ public class DruidQuery rowSignature, virtualColumnRegistry, rexBuilder, - partialQuery.getSelectProject(), + InputAccessor.buildFor( + rexBuilder, + rowSignature, + partialQuery.getSelectProject(), + null), aggregations, aggName, aggCall, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java new file mode 100644 index 00000000000..57b81c68536 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.rel; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.expression.Expressions; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Enables simpler access to input expressions. + * + * In case of aggregates it provides the constants transparently for aggregates. + */ +public class InputAccessor +{ + private final Project project; + private final ImmutableList constants; + private final RexBuilder rexBuilder; + private final RowSignature inputRowSignature; + private final int inputFieldCount; + + public static InputAccessor buildFor( + RexBuilder rexBuilder, + RowSignature inputRowSignature, + @Nullable Project project, + @Nullable ImmutableList constants) + { + return new InputAccessor(rexBuilder, inputRowSignature, project, constants); + } + + private InputAccessor( + RexBuilder rexBuilder, + RowSignature inputRowSignature, + Project project, + ImmutableList constants) + { + this.rexBuilder = rexBuilder; + this.inputRowSignature = inputRowSignature; + this.project = project; + this.constants = constants; + this.inputFieldCount = project != null ? project.getRowType().getFieldCount() : inputRowSignature.size(); + } + + public RexNode getField(int argIndex) + { + + if (argIndex < inputFieldCount) { + return Expressions.fromFieldAccess( + rexBuilder.getTypeFactory(), + inputRowSignature, + project, + argIndex); + } else { + return constants.get(argIndex - inputFieldCount); + } + } + + public List getFields(List argList) + { + return argList + .stream() + .map(i -> getField(i)) + .collect(Collectors.toList()); + } + + public @Nullable Project getProject() + { + return project; + } + + + public RexBuilder getRexBuilder() + { + return rexBuilder; + } + + + public RowSignature getInputRowSignature() + { + return inputRowSignature; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java index 07c5544441d..4039ca8914a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java @@ -177,7 +177,11 @@ public class Windowing sourceRowSignature, null, rexBuilder, - partialQuery.getSelectProject(), + InputAccessor.buildFor( + rexBuilder, + sourceRowSignature, + partialQuery.getSelectProject(), + window.constants), Collections.emptyList(), aggName, aggregateCall, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java index 50bdf80771a..fecabd00ec3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.rule; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -32,6 +31,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -58,7 +58,7 @@ public class GroupByRules final RowSignature rowSignature, @Nullable final VirtualColumnRegistry virtualColumnRegistry, final RexBuilder rexBuilder, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final String name, final AggregateCall call, @@ -74,11 +74,7 @@ public class GroupByRules if (call.filterArg >= 0) { // AGG(xxx) FILTER(WHERE yyy) - final RexNode expression = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - call.filterArg); + final RexNode expression = inputAccessor.getField(call.filterArg); final DimFilter nonOptimizedFilter = Expressions.toFilter( plannerContext, @@ -136,12 +132,10 @@ public class GroupByRules final Aggregation retVal = sqlAggregator.toDruidAggregation( plannerContext, - rowSignature, virtualColumnRegistry, - rexBuilder, name, call, - project, + inputAccessor, existingAggregationsWithSameFilter, finalizeAggregations ); diff --git a/sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest b/sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest new file mode 100644 index 00000000000..16dbe924fdb --- /dev/null +++ b/sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest @@ -0,0 +1,26 @@ +type: "operatorValidation" + +sql: | + SELECT + dim1, + count(333) OVER () cc + FROM foo + WHERE length(dim1)>0 + +expectedOperators: + - type: naivePartition + partitionColumns: [] + - type: "window" + processor: + type: "framedAgg" + frame: { peerType: "ROWS", lowUnbounded: true, lowOffset: 0, uppUnbounded: true, uppOffset: 0 } + aggregations: + - { type: "count", name: "w0" } + +expectedResults: + - ["10.1",5] + - ["2",5] + - ["1",5] + - ["def",5] + - ["abc",5] +