Support constant args in window functions (#15071)

Instead of passing the constants around in a new parameter; InputAccessor was introduced to take care of transparently handling the constants - this new class started picking up some copy-paste debris around field accesses; and made them a little bit more readble.
This commit is contained in:
Zoltan Haindrich 2023-10-08 08:44:25 +02:00 committed by GitHub
parent 7b869fd37a
commit b5a87fd89b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 357 additions and 497 deletions

View File

@ -21,8 +21,6 @@ package org.apache.druid.compressedbigdecimal;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -36,12 +34,12 @@ import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -71,12 +69,10 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
@ -88,13 +84,8 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
// fetch sum column expression
DruidExpression sumColumn = Expressions.toDruidExpression(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
inputAccessor.getInputRowSignature(),
inputAccessor.getField(aggregateCall.getArgList().get(0))
);
if (sumColumn == null) {
@ -114,12 +105,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
Integer size = null;
if (aggregateCall.getArgList().size() >= 2) {
RexNode sizeArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
RexNode sizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
size = ((Number) RexLiteral.value(sizeArg)).intValue();
}
@ -128,12 +114,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
Integer scale = null;
if (aggregateCall.getArgList().size() >= 3) {
RexNode scaleArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
RexNode scaleArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
scale = ((Number) RexLiteral.value(scaleArg)).intValue();
}
@ -141,12 +122,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
Boolean useStrictNumberParsing = null;
if (aggregateCall.getArgList().size() >= 4) {
RexNode useStrictNumberParsingArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(3)
);
RexNode useStrictNumberParsingArg = inputAccessor.getField(aggregateCall.getArgList().get(3));
useStrictNumberParsing = RexLiteral.booleanValue(useStrictNumberParsingArg);
}

View File

@ -20,8 +20,6 @@
package org.apache.druid.query.aggregation.tdigestsketch.sql;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -36,13 +34,12 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -63,25 +60,18 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final RexNode inputOperand = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
inputAccessor.getInputRowSignature(),
inputOperand
);
if (input == null) {
@ -93,12 +83,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION;
if (aggregateCall.getArgList().size() > 1) {
RexNode compressionOperand = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
RexNode compressionOperand = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!compressionOperand.isA(SqlKind.LITERAL)) {
// compressionOperand must be a literal in order to plan.
return null;

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.tdigestsketch.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -39,13 +37,12 @@ import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorF
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchToQuantilePostAggregator;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -66,12 +63,10 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
@ -79,13 +74,8 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
// This is expected to be a tdigest sketch
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
inputAccessor.getInputRowSignature(),
inputAccessor.getField(aggregateCall.getArgList().get(0))
);
if (input == null) {
return null;
@ -95,12 +85,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
final String sketchName = StringUtils.format("%s:agg", name);
// this is expected to be quantile fraction
final RexNode quantileArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode quantileArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!quantileArg.isA(SqlKind.LITERAL)) {
// Quantile must be a literal in order to plan.
@ -110,12 +95,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
final double quantile = ((Number) RexLiteral.value(quantileArg)).floatValue();
Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION;
if (aggregateCall.getArgList().size() > 2) {
final RexNode compressionArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
final RexNode compressionArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
compression = ((Number) RexLiteral.value(compressionArg)).intValue();
}

View File

@ -20,9 +20,7 @@
package org.apache.druid.query.aggregation.datasketches.hll.sql;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
@ -36,7 +34,6 @@ import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggrega
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -44,6 +41,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -66,38 +64,26 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
// Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access
// for string columns.
final RexNode columnRexNode = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0));
final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode);
final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode);
if (columnArg == null) {
return null;
}
final int logK;
if (aggregateCall.getArgList().size() >= 2) {
final RexNode logKarg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode logKarg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!logKarg.isA(SqlKind.LITERAL)) {
// logK must be a literal in order to plan.
@ -111,12 +97,7 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
final String tgtHllType;
if (aggregateCall.getArgList().size() >= 3) {
final RexNode tgtHllTypeArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
final RexNode tgtHllTypeArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
if (!tgtHllTypeArg.isA(SqlKind.LITERAL)) {
// tgtHllType must be a literal in order to plan.
@ -132,9 +113,10 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
if (columnArg.isDirectColumnAccess()
&& rowSignature.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
&& inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
aggregatorFactory = new HllSketchMergeAggregatorFactory(
aggregatorName,
columnArg.getDirectColumn(),

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.quantiles.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -37,14 +35,13 @@ import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAg
import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchToQuantilePostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -75,25 +72,18 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
inputAccessor.getInputRowSignature(),
inputAccessor.getField(aggregateCall.getArgList().get(0))
);
if (input == null) {
return null;
@ -101,12 +91,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory;
final String histogramName = StringUtils.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan.
@ -117,12 +102,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
final int k;
if (aggregateCall.getArgList().size() >= 3) {
final RexNode resolutionArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.quantiles.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -35,14 +33,13 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.SketchQueryContext;
import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -71,25 +68,18 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
inputAccessor.getInputRowSignature(),
inputAccessor.getField(aggregateCall.getArgList().get(0))
);
if (input == null) {
return null;
@ -100,12 +90,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
final int k;
if (aggregateCall.getArgList().size() >= 2) {
final RexNode resolutionArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.

View File

@ -20,9 +20,7 @@
package org.apache.druid.query.aggregation.datasketches.theta.sql;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
@ -34,7 +32,6 @@ import org.apache.druid.query.aggregation.datasketches.theta.SketchMergeAggregat
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -42,6 +39,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -60,38 +58,26 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
// Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access
// for string columns.
final RexNode columnRexNode = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0));
final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode);
final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode);
if (columnArg == null) {
return null;
}
final int sketchSize;
if (aggregateCall.getArgList().size() >= 2) {
final RexNode sketchSizeArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode sketchSizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!sketchSizeArg.isA(SqlKind.LITERAL)) {
// logK must be a literal in order to plan.
@ -107,9 +93,10 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
if (columnArg.isDirectColumnAccess()
&& rowSignature.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
&& inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
aggregatorFactory = new SketchMergeAggregatorFactory(
aggregatorName,
columnArg.getDirectColumn(),

View File

@ -20,9 +20,7 @@
package org.apache.druid.query.aggregation.datasketches.tuple.sql;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -38,7 +36,6 @@ import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketc
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -46,6 +43,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -69,12 +67,10 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
@ -86,12 +82,7 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
final int nominalEntries;
final int metricExpressionEndIndex;
final int lastArgIndex = argList.size() - 1;
final RexNode potentialNominalEntriesArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
argList.get(lastArgIndex)
);
final RexNode potentialNominalEntriesArg = inputAccessor.getField(argList.get(lastArgIndex));
if (potentialNominalEntriesArg.isA(SqlKind.LITERAL) &&
RexLiteral.value(potentialNominalEntriesArg) instanceof Number) {
@ -107,16 +98,11 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
for (int i = 0; i <= metricExpressionEndIndex; i++) {
final String fieldName;
final RexNode columnRexNode = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
argList.get(i)
);
final RexNode columnRexNode = inputAccessor.getField(argList.get(i));
final DruidExpression columnArg = Expressions.toDruidExpression(
plannerContext,
rowSignature,
inputAccessor.getInputRowSignature(),
columnRexNode
);
if (columnArg == null) {
@ -124,9 +110,10 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
}
if (columnArg.isDirectColumnAccess() &&
rowSignature.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
fieldName = columnArg.getDirectColumn();
} else {
final RelDataType dataType = columnRexNode.getType();

View File

@ -20,8 +20,6 @@
package org.apache.druid.query.aggregation.bloom.sql;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -38,13 +36,13 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -65,25 +63,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
final RexNode inputOperand = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
final DruidExpression input = Expressions.toDruidExpression(
plannerContext,
rowSignature,
inputAccessor.getInputRowSignature(),
inputOperand
);
if (input == null) {
@ -92,12 +83,7 @@ public class BloomFilterSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory;
final String aggName = StringUtils.format("%s:agg", name);
final RexNode maxNumEntriesOperand = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode maxNumEntriesOperand = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!maxNumEntriesOperand.isA(SqlKind.LITERAL)) {
// maxNumEntriesOperand must be a literal in order to plan.

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.histogram.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -38,13 +36,12 @@ import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram;
import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogramAggregatorFactory;
import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -65,25 +62,18 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
inputAccessor.getInputRowSignature(),
inputAccessor.getField(aggregateCall.getArgList().get(0))
);
if (input == null) {
return null;
@ -91,12 +81,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory;
final String histogramName = StringUtils.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan.
@ -107,12 +92,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final int numBuckets;
if (aggregateCall.getArgList().size() >= 3) {
final RexNode numBucketsArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
final RexNode numBucketsArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
if (!numBucketsArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.
@ -126,12 +106,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final double lowerLimit;
if (aggregateCall.getArgList().size() >= 4) {
final RexNode lowerLimitArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(3)
);
final RexNode lowerLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(3));
if (!lowerLimitArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.
@ -145,12 +120,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final double upperLimit;
if (aggregateCall.getArgList().size() >= 5) {
final RexNode upperLimitArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(4)
);
final RexNode upperLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(4));
if (!upperLimitArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.
@ -164,12 +134,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final FixedBucketsHistogram.OutlierHandlingMode outlierHandlingMode;
if (aggregateCall.getArgList().size() >= 6) {
final RexNode outlierHandlingModeArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(5)
);
final RexNode outlierHandlingModeArg = inputAccessor.getField(aggregateCall.getArgList().get(5));
if (!outlierHandlingModeArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.histogram.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -39,14 +37,13 @@ import org.apache.druid.query.aggregation.histogram.ApproximateHistogramAggregat
import org.apache.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -67,25 +64,18 @@ public class QuantileSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
inputAccessor.getInputRowSignature(),
inputAccessor.getField(aggregateCall.getArgList().get(0))
);
if (input == null) {
return null;
@ -93,12 +83,7 @@ public class QuantileSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory;
final String histogramName = StringUtils.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan.
@ -109,12 +94,7 @@ public class QuantileSqlAggregator implements SqlAggregator
final int resolution;
if (aggregateCall.getArgList().size() >= 3) {
final RexNode resolutionArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.
@ -170,7 +150,10 @@ public class QuantileSqlAggregator implements SqlAggregator
// No existing match found. Create a new one.
if (input.isDirectColumnAccess()) {
if (rowSignature.getColumnType(input.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) {
if (inputAccessor.getInputRowSignature()
.getColumnType(input.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
histogramName,
input.getDirectColumn(),

View File

@ -21,9 +21,7 @@ package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
@ -40,15 +38,14 @@ import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;
@ -77,25 +74,19 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
final RexNode inputOperand = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
inputAccessor.getInputRowSignature(),
inputOperand
);
if (input == null) {

View File

@ -20,14 +20,13 @@
package org.apache.druid.sql.calcite.aggregation;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import javax.annotation.Nullable;
import java.util.List;
@ -48,28 +47,24 @@ public class Aggregations
*
* 1) They can take direct field accesses or expressions as inputs.
* 2) They cannot implicitly cast strings to numbers when using a direct field access.
*
* @param plannerContext SQL planner context
* @param rowSignature input row signature
* @param call aggregate call object
* @param project project that should be applied before aggregation; may be null
* @param inputAccessor gives access to input fields and schema
*
* @return list of expressions corresponding to aggregator arguments, or null if any cannot be translated
*/
@Nullable
public static List<DruidExpression> getArgumentsForSimpleAggregator(
final RexBuilder rexBuilder,
final PlannerContext plannerContext,
final RowSignature rowSignature,
final AggregateCall call,
@Nullable final Project project
final InputAccessor inputAccessor
)
{
final List<DruidExpression> args = call
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, rowSignature, rexNode))
.map(i -> inputAccessor.getField(i))
.map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, inputAccessor.getInputRowSignature(), rexNode))
.collect(Collectors.toList());
if (args.stream().noneMatch(Objects::isNull)) {

View File

@ -20,8 +20,6 @@
package org.apache.druid.sql.calcite.aggregation;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
@ -30,8 +28,8 @@ import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -66,24 +64,20 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
return delegate.toDruidAggregation(
plannerContext,
rowSignature,
virtualColumnRegistry,
rexBuilder,
name,
aggregateCall,
project,
inputAccessor,
existingAggregations,
finalizeAggregations
);

View File

@ -25,6 +25,7 @@ import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -42,6 +43,44 @@ public interface SqlAggregator
*/
SqlAggFunction calciteFunction();
/**
* Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters;
* they will be applied to your aggregator in a later step.
*
* @param plannerContext SQL planner context
* @param virtualColumnRegistry re-usable virtual column references
* @param name desired output name of the aggregation
* @param aggregateCall aggregate call object
* @param inputAccessor gives access to input fields and schema
* @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely
* ignored if you do not want to re-use existing aggregations.
* @param finalizeAggregations true if this query should include explicit finalization for all of its
* aggregators, where required. This is set for subqueries where Druid's native query
* layer does not do this automatically.
* @return aggregation, or null if the call cannot be translated
*/
@Nullable
default Aggregation toDruidAggregation(
PlannerContext plannerContext,
VirtualColumnRegistry virtualColumnRegistry,
String name,
AggregateCall aggregateCall,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
return toDruidAggregation(plannerContext,
inputAccessor.getInputRowSignature(),
virtualColumnRegistry,
inputAccessor.getRexBuilder(),
name,
aggregateCall,
inputAccessor.getProject(),
existingAggregations,
finalizeAggregations);
}
/**
* Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters;
* they will be applied to your aggregator in a later step.
@ -62,7 +101,7 @@ public interface SqlAggregator
* @return aggregation, or null if the call cannot be translated
*/
@Nullable
Aggregation toDruidAggregation(
default Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
@ -72,5 +111,8 @@ public interface SqlAggregator
Project project,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
);
)
{
throw new RuntimeException("unimplemented fallback method!");
}
}

View File

@ -21,8 +21,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -39,18 +37,17 @@ import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
import java.util.stream.Collectors;
public class ArrayConcatSqlAggregator implements SqlAggregator
{
@ -67,21 +64,15 @@ public class ArrayConcatSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
final List<RexNode> arguments = aggregateCall
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
final List<RexNode> arguments = inputAccessor.getFields(aggregateCall.getArgList());
Integer maxSizeBytes = null;
if (arguments.size() > 1) {
@ -92,7 +83,7 @@ public class ArrayConcatSqlAggregator implements SqlAggregator
}
maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
}
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0));
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0));
final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable();
final String fieldName;

View File

@ -21,9 +21,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -41,18 +39,17 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
import java.util.stream.Collectors;
public class ArraySqlAggregator implements SqlAggregator
{
@ -69,21 +66,16 @@ public class ArraySqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
final List<RexNode> arguments = aggregateCall
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
final List<RexNode> arguments =
inputAccessor.getFields(aggregateCall.getArgList());
Integer maxSizeBytes = null;
if (arguments.size() > 1) {
@ -94,7 +86,7 @@ public class ArraySqlAggregator implements SqlAggregator
}
maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
}
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0));
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0));
if (arg == null) {
// can't translate argument
return null;

View File

@ -22,8 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@ -33,14 +31,13 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -58,23 +55,19 @@ public class AvgSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
rexBuilder,
plannerContext,
rowSignature,
aggregateCall,
project
inputAccessor
);
if (arguments == null) {
@ -85,11 +78,11 @@ public class AvgSqlAggregator implements SqlAggregator
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
countName,
plannerContext,
rowSignature,
inputAccessor.getInputRowSignature(),
virtualColumnRegistry,
rexBuilder,
inputAccessor.getRexBuilder(),
aggregateCall,
project
inputAccessor
);
final DruidExpression arg = Iterables.getOnlyElement(arguments);
@ -108,12 +101,8 @@ public class AvgSqlAggregator implements SqlAggregator
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
} else {
final RexNode resolutionArg = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
final RexNode resolutionArg = inputAccessor.getField(
Iterables.getOnlyElement(aggregateCall.getArgList()));
fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, resolutionArg.getType());
}

View File

@ -21,8 +21,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
@ -41,12 +39,12 @@ import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -122,12 +120,10 @@ public class BitwiseSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
@ -135,8 +131,8 @@ public class BitwiseSqlAggregator implements SqlAggregator
final List<DruidExpression> arguments = aggregateCall
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode))
.map(i -> inputAccessor.getField(i))
.map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode))
.collect(Collectors.toList());
if (arguments.stream().anyMatch(Objects::isNull)) {

View File

@ -22,9 +22,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
@ -42,7 +40,6 @@ import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFact
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -50,6 +47,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -72,26 +70,20 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
// Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access
// for string columns.
final RexNode rexNode = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
final RexNode rexNode = inputAccessor.getField(
Iterables.getOnlyElement(aggregateCall.getArgList()));
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, rexNode);
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode);
if (arg == null) {
return null;
}
@ -100,7 +92,10 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
if (arg.isDirectColumnAccess()
&& rowSignature.getColumnType(arg.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) {
&& inputAccessor.getInputRowSignature()
.getColumnType(arg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true);
} else {
final RelDataType dataType = rexNode.getType();

View File

@ -22,7 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -40,6 +39,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -69,15 +69,10 @@ public class CountSqlAggregator implements SqlAggregator
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final AggregateCall aggregateCall,
final Project project
final InputAccessor inputAccessor
)
{
final RexNode rexNode = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
final RexNode rexNode = inputAccessor.getField(Iterables.getOnlyElement(aggregateCall.getArgList()));
if (rexNode.getType().isNullable()) {
final DimFilter nonNullFilter = Expressions.toFilter(
@ -102,28 +97,25 @@ public class CountSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final List<DruidExpression> args = Aggregations.getArgumentsForSimpleAggregator(
rexBuilder,
plannerContext,
rowSignature,
aggregateCall,
project
inputAccessor
);
if (args == null) {
return null;
}
// FIXME: is-all-literal
if (args.isEmpty()) {
// COUNT(*)
return Aggregation.create(new CountAggregatorFactory(name));
@ -132,12 +124,10 @@ public class CountSqlAggregator implements SqlAggregator
if (plannerContext.getPlannerConfig().isUseApproximateCountDistinct()) {
return approxCountDistinctAggregator.toDruidAggregation(
plannerContext,
rowSignature,
virtualColumnRegistry,
rexBuilder,
name,
aggregateCall,
project,
inputAccessor,
existingAggregations,
finalizeAggregations
);
@ -150,11 +140,11 @@ public class CountSqlAggregator implements SqlAggregator
AggregatorFactory theCount = createCountAggregatorFactory(
name,
plannerContext,
rowSignature,
inputAccessor.getInputRowSignature(),
virtualColumnRegistry,
rexBuilder,
inputAccessor.getRexBuilder(),
aggregateCall,
project
inputAccessor
);
return Aggregation.create(theCount);

View File

@ -20,9 +20,7 @@
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -53,19 +51,18 @@ import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class EarliestLatestAnySqlAggregator implements SqlAggregator
{
@ -180,23 +177,17 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final List<RexNode> rexNodes = aggregateCall
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
final List<RexNode> rexNodes = inputAccessor.getFields(aggregateCall.getArgList());
final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes);
final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes);
if (args == null) {
return null;
@ -216,7 +207,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
final String fieldName = getColumnName(plannerContext, virtualColumnRegistry, args.get(0), rexNodes.get(0));
if (!rowSignature.contains(ColumnHolder.TIME_COLUMN_NAME) && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) {
if (!inputAccessor.getInputRowSignature().contains(ColumnHolder.TIME_COLUMN_NAME)
&& (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) {
// This code is being run as part of the exploratory volcano planner, currently, the definition of these
// aggregators does not tell Calcite that they depend on a __time column being in existence, instead we are
// allowing the volcano planner to explore paths that put projections which eliminate the time column in between

View File

@ -20,8 +20,6 @@
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -38,19 +36,18 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class EarliestLatestBySqlAggregator implements SqlAggregator
{
@ -76,23 +73,17 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final List<RexNode> rexNodes = aggregateCall
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
final List<RexNode> rexNodes = inputAccessor.getFields(aggregateCall.getArgList());
final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes);
final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes);
if (args == null) {
return null;

View File

@ -22,7 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@ -34,6 +33,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -53,24 +53,22 @@ public class GroupingSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
List<String> arguments = aggregateCall.getArgList()
.stream()
.map(i -> getColumnName(
plannerContext,
rowSignature,
project,
inputAccessor.getInputRowSignature(),
inputAccessor.getProject(),
virtualColumnRegistry,
rexBuilder.getTypeFactory(),
inputAccessor.getRexBuilder().getTypeFactory(),
i
))
.filter(Objects::nonNull)

View File

@ -21,18 +21,16 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -59,12 +57,10 @@ public class LiteralSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
@ -73,7 +69,7 @@ public class LiteralSqlAggregator implements SqlAggregator
return null;
}
final RexNode literal = aggregateCall.rexList.get(0);
final DruidExpression expr = Expressions.toDruidExpression(plannerContext, rowSignature, literal);
final DruidExpression expr = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), literal);
if (expr == null) {
return null;

View File

@ -21,18 +21,16 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.druid.error.DruidException;
import org.apache.druid.error.InvalidSqlInput;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -57,12 +55,10 @@ public abstract class SimpleSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
@ -72,11 +68,9 @@ public abstract class SimpleSqlAggregator implements SqlAggregator
}
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
rexBuilder,
plannerContext,
rowSignature,
aggregateCall,
project
inputAccessor
);
if (arguments == null) {

View File

@ -21,9 +21,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
@ -47,13 +45,13 @@ import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;
@ -89,12 +87,10 @@ public class StringSqlAggregator implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
@ -102,20 +98,15 @@ public class StringSqlAggregator implements SqlAggregator
final List<DruidExpression> arguments = aggregateCall
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode))
.map(i -> inputAccessor.getField(i))
.map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode))
.collect(Collectors.toList());
if (arguments.stream().anyMatch(Objects::isNull)) {
return null;
}
RexNode separatorNode = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
RexNode separatorNode = inputAccessor.getField(aggregateCall.getArgList().get(1));
if (!separatorNode.isA(SqlKind.LITERAL)) {
// separator must be a literal
return null;
@ -133,12 +124,7 @@ public class StringSqlAggregator implements SqlAggregator
Integer maxSizeBytes = null;
if (arguments.size() > 2) {
RexNode maxBytes = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
RexNode maxBytes = inputAccessor.getField(aggregateCall.getArgList().get(2));
if (!maxBytes.isA(SqlKind.LITERAL)) {
// maxBytes must be a literal
return null;

View File

@ -20,14 +20,12 @@
package org.apache.druid.sql.calcite.expression;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -55,12 +53,10 @@ public class WindowSqlAggregate implements SqlAggregator
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)

View File

@ -580,7 +580,11 @@ public class DruidQuery
rowSignature,
virtualColumnRegistry,
rexBuilder,
partialQuery.getSelectProject(),
InputAccessor.buildFor(
rexBuilder,
rowSignature,
partialQuery.getSelectProject(),
null),
aggregations,
aggName,
aggCall,

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rel;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.Expressions;
import javax.annotation.Nullable;
import java.util.List;
import java.util.stream.Collectors;
/**
* Enables simpler access to input expressions.
*
* In case of aggregates it provides the constants transparently for aggregates.
*/
public class InputAccessor
{
private final Project project;
private final ImmutableList<RexLiteral> constants;
private final RexBuilder rexBuilder;
private final RowSignature inputRowSignature;
private final int inputFieldCount;
public static InputAccessor buildFor(
RexBuilder rexBuilder,
RowSignature inputRowSignature,
@Nullable Project project,
@Nullable ImmutableList<RexLiteral> constants)
{
return new InputAccessor(rexBuilder, inputRowSignature, project, constants);
}
private InputAccessor(
RexBuilder rexBuilder,
RowSignature inputRowSignature,
Project project,
ImmutableList<RexLiteral> constants)
{
this.rexBuilder = rexBuilder;
this.inputRowSignature = inputRowSignature;
this.project = project;
this.constants = constants;
this.inputFieldCount = project != null ? project.getRowType().getFieldCount() : inputRowSignature.size();
}
public RexNode getField(int argIndex)
{
if (argIndex < inputFieldCount) {
return Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
inputRowSignature,
project,
argIndex);
} else {
return constants.get(argIndex - inputFieldCount);
}
}
public List<RexNode> getFields(List<Integer> argList)
{
return argList
.stream()
.map(i -> getField(i))
.collect(Collectors.toList());
}
public @Nullable Project getProject()
{
return project;
}
public RexBuilder getRexBuilder()
{
return rexBuilder;
}
public RowSignature getInputRowSignature()
{
return inputRowSignature;
}
}

View File

@ -177,7 +177,11 @@ public class Windowing
sourceRowSignature,
null,
rexBuilder,
partialQuery.getSelectProject(),
InputAccessor.buildFor(
rexBuilder,
sourceRowSignature,
partialQuery.getSelectProject(),
window.constants),
Collections.emptyList(),
aggName,
aggregateCall,

View File

@ -20,7 +20,6 @@
package org.apache.druid.sql.calcite.rule;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.druid.query.aggregation.AggregatorFactory;
@ -32,6 +31,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
@ -58,7 +58,7 @@ public class GroupByRules
final RowSignature rowSignature,
@Nullable final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final Project project,
final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations,
final String name,
final AggregateCall call,
@ -74,11 +74,7 @@ public class GroupByRules
if (call.filterArg >= 0) {
// AGG(xxx) FILTER(WHERE yyy)
final RexNode expression = Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
rowSignature,
project,
call.filterArg);
final RexNode expression = inputAccessor.getField(call.filterArg);
final DimFilter nonOptimizedFilter = Expressions.toFilter(
plannerContext,
@ -136,12 +132,10 @@ public class GroupByRules
final Aggregation retVal = sqlAggregator.toDruidAggregation(
plannerContext,
rowSignature,
virtualColumnRegistry,
rexBuilder,
name,
call,
project,
inputAccessor,
existingAggregationsWithSameFilter,
finalizeAggregations
);

View File

@ -0,0 +1,26 @@
type: "operatorValidation"
sql: |
SELECT
dim1,
count(333) OVER () cc
FROM foo
WHERE length(dim1)>0
expectedOperators:
- type: naivePartition
partitionColumns: []
- type: "window"
processor:
type: "framedAgg"
frame: { peerType: "ROWS", lowUnbounded: true, lowOffset: 0, uppUnbounded: true, uppOffset: 0 }
aggregations:
- { type: "count", name: "w0" }
expectedResults:
- ["10.1",5]
- ["2",5]
- ["1",5]
- ["def",5]
- ["abc",5]