Support constant args in window functions (#15071)

Instead of passing the constants around in a new parameter; InputAccessor was introduced to take care of transparently handling the constants - this new class started picking up some copy-paste debris around field accesses; and made them a little bit more readble.
This commit is contained in:
Zoltan Haindrich 2023-10-08 08:44:25 +02:00 committed by GitHub
parent 7b869fd37a
commit b5a87fd89b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 357 additions and 497 deletions

View File

@ -21,8 +21,6 @@ package org.apache.druid.compressedbigdecimal;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -36,12 +34,12 @@ import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -71,12 +69,10 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
@ -88,13 +84,8 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
// fetch sum column expression // fetch sum column expression
DruidExpression sumColumn = Expressions.toDruidExpression( DruidExpression sumColumn = Expressions.toDruidExpression(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
Expressions.fromFieldAccess( inputAccessor.getField(aggregateCall.getArgList().get(0))
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
); );
if (sumColumn == null) { if (sumColumn == null) {
@ -114,12 +105,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
Integer size = null; Integer size = null;
if (aggregateCall.getArgList().size() >= 2) { if (aggregateCall.getArgList().size() >= 2) {
RexNode sizeArg = Expressions.fromFieldAccess( RexNode sizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
size = ((Number) RexLiteral.value(sizeArg)).intValue(); size = ((Number) RexLiteral.value(sizeArg)).intValue();
} }
@ -128,12 +114,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
Integer scale = null; Integer scale = null;
if (aggregateCall.getArgList().size() >= 3) { if (aggregateCall.getArgList().size() >= 3) {
RexNode scaleArg = Expressions.fromFieldAccess( RexNode scaleArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
scale = ((Number) RexLiteral.value(scaleArg)).intValue(); scale = ((Number) RexLiteral.value(scaleArg)).intValue();
} }
@ -141,12 +122,7 @@ public abstract class CompressedBigDecimalSqlAggregatorBase implements SqlAggreg
Boolean useStrictNumberParsing = null; Boolean useStrictNumberParsing = null;
if (aggregateCall.getArgList().size() >= 4) { if (aggregateCall.getArgList().size() >= 4) {
RexNode useStrictNumberParsingArg = Expressions.fromFieldAccess( RexNode useStrictNumberParsingArg = inputAccessor.getField(aggregateCall.getArgList().get(3));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(3)
);
useStrictNumberParsing = RexLiteral.booleanValue(useStrictNumberParsingArg); useStrictNumberParsing = RexLiteral.booleanValue(useStrictNumberParsingArg);
} }

View File

@ -20,8 +20,6 @@
package org.apache.druid.query.aggregation.tdigestsketch.sql; package org.apache.druid.query.aggregation.tdigestsketch.sql;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -36,13 +34,12 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -63,25 +60,18 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final RexNode inputOperand = Expressions.fromFieldAccess( final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
inputOperand inputOperand
); );
if (input == null) { if (input == null) {
@ -93,12 +83,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION; Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION;
if (aggregateCall.getArgList().size() > 1) { if (aggregateCall.getArgList().size() > 1) {
RexNode compressionOperand = Expressions.fromFieldAccess( RexNode compressionOperand = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!compressionOperand.isA(SqlKind.LITERAL)) { if (!compressionOperand.isA(SqlKind.LITERAL)) {
// compressionOperand must be a literal in order to plan. // compressionOperand must be a literal in order to plan.
return null; return null;

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.tdigestsketch.sql;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -39,13 +37,12 @@ import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorF
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchToQuantilePostAggregator; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchToQuantilePostAggregator;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -66,12 +63,10 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
@ -79,13 +74,8 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
// This is expected to be a tdigest sketch // This is expected to be a tdigest sketch
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
Expressions.fromFieldAccess( inputAccessor.getField(aggregateCall.getArgList().get(0))
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
); );
if (input == null) { if (input == null) {
return null; return null;
@ -95,12 +85,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
final String sketchName = StringUtils.format("%s:agg", name); final String sketchName = StringUtils.format("%s:agg", name);
// this is expected to be quantile fraction // this is expected to be quantile fraction
final RexNode quantileArg = Expressions.fromFieldAccess( final RexNode quantileArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!quantileArg.isA(SqlKind.LITERAL)) { if (!quantileArg.isA(SqlKind.LITERAL)) {
// Quantile must be a literal in order to plan. // Quantile must be a literal in order to plan.
@ -110,12 +95,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
final double quantile = ((Number) RexLiteral.value(quantileArg)).floatValue(); final double quantile = ((Number) RexLiteral.value(quantileArg)).floatValue();
Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION; Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION;
if (aggregateCall.getArgList().size() > 2) { if (aggregateCall.getArgList().size() > 2) {
final RexNode compressionArg = Expressions.fromFieldAccess( final RexNode compressionArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
compression = ((Number) RexLiteral.value(compressionArg)).intValue(); compression = ((Number) RexLiteral.value(compressionArg)).intValue();
} }

View File

@ -20,9 +20,7 @@
package org.apache.druid.query.aggregation.datasketches.hll.sql; package org.apache.druid.query.aggregation.datasketches.hll.sql;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
@ -36,7 +34,6 @@ import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggrega
import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -44,6 +41,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -66,38 +64,26 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
// Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access
// for string columns. // for string columns.
final RexNode columnRexNode = Expressions.fromFieldAccess( final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode); final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode);
if (columnArg == null) { if (columnArg == null) {
return null; return null;
} }
final int logK; final int logK;
if (aggregateCall.getArgList().size() >= 2) { if (aggregateCall.getArgList().size() >= 2) {
final RexNode logKarg = Expressions.fromFieldAccess( final RexNode logKarg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!logKarg.isA(SqlKind.LITERAL)) { if (!logKarg.isA(SqlKind.LITERAL)) {
// logK must be a literal in order to plan. // logK must be a literal in order to plan.
@ -111,12 +97,7 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
final String tgtHllType; final String tgtHllType;
if (aggregateCall.getArgList().size() >= 3) { if (aggregateCall.getArgList().size() >= 3) {
final RexNode tgtHllTypeArg = Expressions.fromFieldAccess( final RexNode tgtHllTypeArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
if (!tgtHllTypeArg.isA(SqlKind.LITERAL)) { if (!tgtHllTypeArg.isA(SqlKind.LITERAL)) {
// tgtHllType must be a literal in order to plan. // tgtHllType must be a literal in order to plan.
@ -132,7 +113,8 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
if (columnArg.isDirectColumnAccess() if (columnArg.isDirectColumnAccess()
&& rowSignature.getColumnType(columnArg.getDirectColumn()) && inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX)) .map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) { .orElse(false)) {
aggregatorFactory = new HllSketchMergeAggregatorFactory( aggregatorFactory = new HllSketchMergeAggregatorFactory(

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.quantiles.sql;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -37,14 +35,13 @@ import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAg
import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchToQuantilePostAggregator; import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchToQuantilePostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -75,25 +72,18 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
Expressions.fromFieldAccess( inputAccessor.getField(aggregateCall.getArgList().get(0))
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
); );
if (input == null) { if (input == null) {
return null; return null;
@ -101,12 +91,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String histogramName = StringUtils.format("%s:agg", name); final String histogramName = StringUtils.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess( final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!probabilityArg.isA(SqlKind.LITERAL)) { if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan. // Probability must be a literal in order to plan.
@ -117,12 +102,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
final int k; final int k;
if (aggregateCall.getArgList().size() >= 3) { if (aggregateCall.getArgList().size() >= 3) {
final RexNode resolutionArg = Expressions.fromFieldAccess( final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
if (!resolutionArg.isA(SqlKind.LITERAL)) { if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.datasketches.quantiles.sql;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -35,14 +33,13 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.SketchQueryContext; import org.apache.druid.query.aggregation.datasketches.SketchQueryContext;
import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAggregatorFactory; import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAggregatorFactory;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -71,25 +68,18 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
Expressions.fromFieldAccess( inputAccessor.getField(aggregateCall.getArgList().get(0))
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
); );
if (input == null) { if (input == null) {
return null; return null;
@ -100,12 +90,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
final int k; final int k;
if (aggregateCall.getArgList().size() >= 2) { if (aggregateCall.getArgList().size() >= 2) {
final RexNode resolutionArg = Expressions.fromFieldAccess( final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!resolutionArg.isA(SqlKind.LITERAL)) { if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.

View File

@ -20,9 +20,7 @@
package org.apache.druid.query.aggregation.datasketches.theta.sql; package org.apache.druid.query.aggregation.datasketches.theta.sql;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
@ -34,7 +32,6 @@ import org.apache.druid.query.aggregation.datasketches.theta.SketchMergeAggregat
import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -42,6 +39,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -60,38 +58,26 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
// Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access
// for string columns. // for string columns.
final RexNode columnRexNode = Expressions.fromFieldAccess( final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode); final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode);
if (columnArg == null) { if (columnArg == null) {
return null; return null;
} }
final int sketchSize; final int sketchSize;
if (aggregateCall.getArgList().size() >= 2) { if (aggregateCall.getArgList().size() >= 2) {
final RexNode sketchSizeArg = Expressions.fromFieldAccess( final RexNode sketchSizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!sketchSizeArg.isA(SqlKind.LITERAL)) { if (!sketchSizeArg.isA(SqlKind.LITERAL)) {
// logK must be a literal in order to plan. // logK must be a literal in order to plan.
@ -107,7 +93,8 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
if (columnArg.isDirectColumnAccess() if (columnArg.isDirectColumnAccess()
&& rowSignature.getColumnType(columnArg.getDirectColumn()) && inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX)) .map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) { .orElse(false)) {
aggregatorFactory = new SketchMergeAggregatorFactory( aggregatorFactory = new SketchMergeAggregatorFactory(

View File

@ -20,9 +20,7 @@
package org.apache.druid.query.aggregation.datasketches.tuple.sql; package org.apache.druid.query.aggregation.datasketches.tuple.sql;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -38,7 +36,6 @@ import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketc
import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -46,6 +43,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -69,12 +67,10 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
@ -86,12 +82,7 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
final int nominalEntries; final int nominalEntries;
final int metricExpressionEndIndex; final int metricExpressionEndIndex;
final int lastArgIndex = argList.size() - 1; final int lastArgIndex = argList.size() - 1;
final RexNode potentialNominalEntriesArg = Expressions.fromFieldAccess( final RexNode potentialNominalEntriesArg = inputAccessor.getField(argList.get(lastArgIndex));
rexBuilder.getTypeFactory(),
rowSignature,
project,
argList.get(lastArgIndex)
);
if (potentialNominalEntriesArg.isA(SqlKind.LITERAL) && if (potentialNominalEntriesArg.isA(SqlKind.LITERAL) &&
RexLiteral.value(potentialNominalEntriesArg) instanceof Number) { RexLiteral.value(potentialNominalEntriesArg) instanceof Number) {
@ -107,16 +98,11 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
for (int i = 0; i <= metricExpressionEndIndex; i++) { for (int i = 0; i <= metricExpressionEndIndex; i++) {
final String fieldName; final String fieldName;
final RexNode columnRexNode = Expressions.fromFieldAccess( final RexNode columnRexNode = inputAccessor.getField(argList.get(i));
rexBuilder.getTypeFactory(),
rowSignature,
project,
argList.get(i)
);
final DruidExpression columnArg = Expressions.toDruidExpression( final DruidExpression columnArg = Expressions.toDruidExpression(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
columnRexNode columnRexNode
); );
if (columnArg == null) { if (columnArg == null) {
@ -124,7 +110,8 @@ public class ArrayOfDoublesSketchSqlAggregator implements SqlAggregator
} }
if (columnArg.isDirectColumnAccess() && if (columnArg.isDirectColumnAccess() &&
rowSignature.getColumnType(columnArg.getDirectColumn()) inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX)) .map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) { .orElse(false)) {
fieldName = columnArg.getDirectColumn(); fieldName = columnArg.getDirectColumn();

View File

@ -20,8 +20,6 @@
package org.apache.druid.query.aggregation.bloom.sql; package org.apache.druid.query.aggregation.bloom.sql;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -38,13 +36,13 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -65,25 +63,18 @@ public class BloomFilterSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
final RexNode inputOperand = Expressions.fromFieldAccess( final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression input = Expressions.toDruidExpression( final DruidExpression input = Expressions.toDruidExpression(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
inputOperand inputOperand
); );
if (input == null) { if (input == null) {
@ -92,12 +83,7 @@ public class BloomFilterSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String aggName = StringUtils.format("%s:agg", name); final String aggName = StringUtils.format("%s:agg", name);
final RexNode maxNumEntriesOperand = Expressions.fromFieldAccess( final RexNode maxNumEntriesOperand = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!maxNumEntriesOperand.isA(SqlKind.LITERAL)) { if (!maxNumEntriesOperand.isA(SqlKind.LITERAL)) {
// maxNumEntriesOperand must be a literal in order to plan. // maxNumEntriesOperand must be a literal in order to plan.

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.histogram.sql;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -38,13 +36,12 @@ import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogram;
import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogramAggregatorFactory; import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogramAggregatorFactory;
import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator; import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -65,25 +62,18 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
Expressions.fromFieldAccess( inputAccessor.getField(aggregateCall.getArgList().get(0))
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
); );
if (input == null) { if (input == null) {
return null; return null;
@ -91,12 +81,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String histogramName = StringUtils.format("%s:agg", name); final String histogramName = StringUtils.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess( final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!probabilityArg.isA(SqlKind.LITERAL)) { if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan. // Probability must be a literal in order to plan.
@ -107,12 +92,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final int numBuckets; final int numBuckets;
if (aggregateCall.getArgList().size() >= 3) { if (aggregateCall.getArgList().size() >= 3) {
final RexNode numBucketsArg = Expressions.fromFieldAccess( final RexNode numBucketsArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
if (!numBucketsArg.isA(SqlKind.LITERAL)) { if (!numBucketsArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.
@ -126,12 +106,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final double lowerLimit; final double lowerLimit;
if (aggregateCall.getArgList().size() >= 4) { if (aggregateCall.getArgList().size() >= 4) {
final RexNode lowerLimitArg = Expressions.fromFieldAccess( final RexNode lowerLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(3));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(3)
);
if (!lowerLimitArg.isA(SqlKind.LITERAL)) { if (!lowerLimitArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.
@ -145,12 +120,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final double upperLimit; final double upperLimit;
if (aggregateCall.getArgList().size() >= 5) { if (aggregateCall.getArgList().size() >= 5) {
final RexNode upperLimitArg = Expressions.fromFieldAccess( final RexNode upperLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(4));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(4)
);
if (!upperLimitArg.isA(SqlKind.LITERAL)) { if (!upperLimitArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.
@ -164,12 +134,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
final FixedBucketsHistogram.OutlierHandlingMode outlierHandlingMode; final FixedBucketsHistogram.OutlierHandlingMode outlierHandlingMode;
if (aggregateCall.getArgList().size() >= 6) { if (aggregateCall.getArgList().size() >= 6) {
final RexNode outlierHandlingModeArg = Expressions.fromFieldAccess( final RexNode outlierHandlingModeArg = inputAccessor.getField(aggregateCall.getArgList().get(5));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(5)
);
if (!outlierHandlingModeArg.isA(SqlKind.LITERAL)) { if (!outlierHandlingModeArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.

View File

@ -21,8 +21,6 @@ package org.apache.druid.query.aggregation.histogram.sql;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -39,14 +37,13 @@ import org.apache.druid.query.aggregation.histogram.ApproximateHistogramAggregat
import org.apache.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory; import org.apache.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator; import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -67,25 +64,18 @@ public class QuantileSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
Expressions.fromFieldAccess( inputAccessor.getField(aggregateCall.getArgList().get(0))
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
); );
if (input == null) { if (input == null) {
return null; return null;
@ -93,12 +83,7 @@ public class QuantileSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String histogramName = StringUtils.format("%s:agg", name); final String histogramName = StringUtils.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess( final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!probabilityArg.isA(SqlKind.LITERAL)) { if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan. // Probability must be a literal in order to plan.
@ -109,12 +94,7 @@ public class QuantileSqlAggregator implements SqlAggregator
final int resolution; final int resolution;
if (aggregateCall.getArgList().size() >= 3) { if (aggregateCall.getArgList().size() >= 3) {
final RexNode resolutionArg = Expressions.fromFieldAccess( final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
if (!resolutionArg.isA(SqlKind.LITERAL)) { if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan. // Resolution must be a literal in order to plan.
@ -170,7 +150,10 @@ public class QuantileSqlAggregator implements SqlAggregator
// No existing match found. Create a new one. // No existing match found. Create a new one.
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
if (rowSignature.getColumnType(input.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) { if (inputAccessor.getInputRowSignature()
.getColumnType(input.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory( aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
histogramName, histogramName,
input.getDirectColumn(), input.getDirectColumn(),

View File

@ -21,9 +21,7 @@ package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
@ -40,15 +38,14 @@ import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures; import org.apache.druid.sql.calcite.table.RowSignatures;
@ -77,25 +74,19 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
final RexNode inputOperand = Expressions.fromFieldAccess( final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
inputOperand inputOperand
); );
if (input == null) { if (input == null) {

View File

@ -20,14 +20,13 @@
package org.apache.druid.sql.calcite.aggregation; package org.apache.druid.sql.calcite.aggregation;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
@ -48,28 +47,24 @@ public class Aggregations
* *
* 1) They can take direct field accesses or expressions as inputs. * 1) They can take direct field accesses or expressions as inputs.
* 2) They cannot implicitly cast strings to numbers when using a direct field access. * 2) They cannot implicitly cast strings to numbers when using a direct field access.
*
* @param plannerContext SQL planner context * @param plannerContext SQL planner context
* @param rowSignature input row signature
* @param call aggregate call object * @param call aggregate call object
* @param project project that should be applied before aggregation; may be null * @param inputAccessor gives access to input fields and schema
* *
* @return list of expressions corresponding to aggregator arguments, or null if any cannot be translated * @return list of expressions corresponding to aggregator arguments, or null if any cannot be translated
*/ */
@Nullable @Nullable
public static List<DruidExpression> getArgumentsForSimpleAggregator( public static List<DruidExpression> getArgumentsForSimpleAggregator(
final RexBuilder rexBuilder,
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final AggregateCall call, final AggregateCall call,
@Nullable final Project project final InputAccessor inputAccessor
) )
{ {
final List<DruidExpression> args = call final List<DruidExpression> args = call
.getArgList() .getArgList()
.stream() .stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) .map(i -> inputAccessor.getField(i))
.map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, rowSignature, rexNode)) .map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, inputAccessor.getInputRowSignature(), rexNode))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (args.stream().noneMatch(Objects::isNull)) { if (args.stream().noneMatch(Objects::isNull)) {

View File

@ -20,8 +20,6 @@
package org.apache.druid.sql.calcite.aggregation; package org.apache.druid.sql.calcite.aggregation;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
@ -30,8 +28,8 @@ import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality; import org.apache.calcite.util.Optionality;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -66,24 +64,20 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
return delegate.toDruidAggregation( return delegate.toDruidAggregation(
plannerContext, plannerContext,
rowSignature,
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder,
name, name,
aggregateCall, aggregateCall,
project, inputAccessor,
existingAggregations, existingAggregations,
finalizeAggregations finalizeAggregations
); );

View File

@ -25,6 +25,7 @@ import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -42,6 +43,44 @@ public interface SqlAggregator
*/ */
SqlAggFunction calciteFunction(); SqlAggFunction calciteFunction();
/**
* Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters;
* they will be applied to your aggregator in a later step.
*
* @param plannerContext SQL planner context
* @param virtualColumnRegistry re-usable virtual column references
* @param name desired output name of the aggregation
* @param aggregateCall aggregate call object
* @param inputAccessor gives access to input fields and schema
* @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely
* ignored if you do not want to re-use existing aggregations.
* @param finalizeAggregations true if this query should include explicit finalization for all of its
* aggregators, where required. This is set for subqueries where Druid's native query
* layer does not do this automatically.
* @return aggregation, or null if the call cannot be translated
*/
@Nullable
default Aggregation toDruidAggregation(
PlannerContext plannerContext,
VirtualColumnRegistry virtualColumnRegistry,
String name,
AggregateCall aggregateCall,
InputAccessor inputAccessor,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
return toDruidAggregation(plannerContext,
inputAccessor.getInputRowSignature(),
virtualColumnRegistry,
inputAccessor.getRexBuilder(),
name,
aggregateCall,
inputAccessor.getProject(),
existingAggregations,
finalizeAggregations);
}
/** /**
* Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters; * Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters;
* they will be applied to your aggregator in a later step. * they will be applied to your aggregator in a later step.
@ -62,7 +101,7 @@ public interface SqlAggregator
* @return aggregation, or null if the call cannot be translated * @return aggregation, or null if the call cannot be translated
*/ */
@Nullable @Nullable
Aggregation toDruidAggregation( default Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature, RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
@ -72,5 +111,8 @@ public interface SqlAggregator
Project project, Project project,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
); )
{
throw new RuntimeException("unimplemented fallback method!");
}
} }

View File

@ -21,8 +21,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -39,18 +37,17 @@ import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class ArrayConcatSqlAggregator implements SqlAggregator public class ArrayConcatSqlAggregator implements SqlAggregator
{ {
@ -67,21 +64,15 @@ public class ArrayConcatSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
final List<RexNode> arguments = aggregateCall final List<RexNode> arguments = inputAccessor.getFields(aggregateCall.getArgList());
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
Integer maxSizeBytes = null; Integer maxSizeBytes = null;
if (arguments.size() > 1) { if (arguments.size() > 1) {
@ -92,7 +83,7 @@ public class ArrayConcatSqlAggregator implements SqlAggregator
} }
maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
} }
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0)); final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0));
final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable(); final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable();
final String fieldName; final String fieldName;

View File

@ -21,9 +21,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -41,18 +39,17 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class ArraySqlAggregator implements SqlAggregator public class ArraySqlAggregator implements SqlAggregator
{ {
@ -69,21 +66,16 @@ public class ArraySqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
{ {
final List<RexNode> arguments = aggregateCall final List<RexNode> arguments =
.getArgList() inputAccessor.getFields(aggregateCall.getArgList());
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
Integer maxSizeBytes = null; Integer maxSizeBytes = null;
if (arguments.size() > 1) { if (arguments.size() > 1) {
@ -94,7 +86,7 @@ public class ArraySqlAggregator implements SqlAggregator
} }
maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
} }
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0)); final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0));
if (arg == null) { if (arg == null) {
// can't translate argument // can't translate argument
return null; return null;

View File

@ -22,8 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@ -33,14 +31,13 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -58,23 +55,19 @@ public class AvgSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator( final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
rexBuilder,
plannerContext, plannerContext,
rowSignature,
aggregateCall, aggregateCall,
project inputAccessor
); );
if (arguments == null) { if (arguments == null) {
@ -85,11 +78,11 @@ public class AvgSqlAggregator implements SqlAggregator
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory( final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
countName, countName,
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder, inputAccessor.getRexBuilder(),
aggregateCall, aggregateCall,
project inputAccessor
); );
final DruidExpression arg = Iterables.getOnlyElement(arguments); final DruidExpression arg = Iterables.getOnlyElement(arguments);
@ -108,12 +101,8 @@ public class AvgSqlAggregator implements SqlAggregator
if (arg.isDirectColumnAccess()) { if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn(); fieldName = arg.getDirectColumn();
} else { } else {
final RexNode resolutionArg = Expressions.fromFieldAccess( final RexNode resolutionArg = inputAccessor.getField(
rexBuilder.getTypeFactory(), Iterables.getOnlyElement(aggregateCall.getArgList()));
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, resolutionArg.getType()); fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, resolutionArg.getType());
} }

View File

@ -21,8 +21,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
@ -41,12 +39,12 @@ import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -122,12 +120,10 @@ public class BitwiseSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
@ -135,8 +131,8 @@ public class BitwiseSqlAggregator implements SqlAggregator
final List<DruidExpression> arguments = aggregateCall final List<DruidExpression> arguments = aggregateCall
.getArgList() .getArgList()
.stream() .stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) .map(i -> inputAccessor.getField(i))
.map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) .map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (arguments.stream().anyMatch(Objects::isNull)) { if (arguments.stream().anyMatch(Objects::isNull)) {

View File

@ -22,9 +22,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlFunctionCategory;
@ -42,7 +40,6 @@ import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFact
import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
@ -50,6 +47,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -72,26 +70,20 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
// Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access
// for string columns. // for string columns.
final RexNode rexNode = Expressions.fromFieldAccess( final RexNode rexNode = inputAccessor.getField(
rexBuilder.getTypeFactory(), Iterables.getOnlyElement(aggregateCall.getArgList()));
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, rexNode); final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode);
if (arg == null) { if (arg == null) {
return null; return null;
} }
@ -100,7 +92,10 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
if (arg.isDirectColumnAccess() if (arg.isDirectColumnAccess()
&& rowSignature.getColumnType(arg.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) { && inputAccessor.getInputRowSignature()
.getColumnType(arg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.orElse(false)) {
aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true); aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true);
} else { } else {
final RelDataType dataType = rexNode.getType(); final RelDataType dataType = rexNode.getType();

View File

@ -22,7 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -40,6 +39,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -69,15 +69,10 @@ public class CountSqlAggregator implements SqlAggregator
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder, final RexBuilder rexBuilder,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project final InputAccessor inputAccessor
) )
{ {
final RexNode rexNode = Expressions.fromFieldAccess( final RexNode rexNode = inputAccessor.getField(Iterables.getOnlyElement(aggregateCall.getArgList()));
rexBuilder.getTypeFactory(),
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
if (rexNode.getType().isNullable()) { if (rexNode.getType().isNullable()) {
final DimFilter nonNullFilter = Expressions.toFilter( final DimFilter nonNullFilter = Expressions.toFilter(
@ -102,28 +97,25 @@ public class CountSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final List<DruidExpression> args = Aggregations.getArgumentsForSimpleAggregator( final List<DruidExpression> args = Aggregations.getArgumentsForSimpleAggregator(
rexBuilder,
plannerContext, plannerContext,
rowSignature,
aggregateCall, aggregateCall,
project inputAccessor
); );
if (args == null) { if (args == null) {
return null; return null;
} }
// FIXME: is-all-literal
if (args.isEmpty()) { if (args.isEmpty()) {
// COUNT(*) // COUNT(*)
return Aggregation.create(new CountAggregatorFactory(name)); return Aggregation.create(new CountAggregatorFactory(name));
@ -132,12 +124,10 @@ public class CountSqlAggregator implements SqlAggregator
if (plannerContext.getPlannerConfig().isUseApproximateCountDistinct()) { if (plannerContext.getPlannerConfig().isUseApproximateCountDistinct()) {
return approxCountDistinctAggregator.toDruidAggregation( return approxCountDistinctAggregator.toDruidAggregation(
plannerContext, plannerContext,
rowSignature,
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder,
name, name,
aggregateCall, aggregateCall,
project, inputAccessor,
existingAggregations, existingAggregations,
finalizeAggregations finalizeAggregations
); );
@ -150,11 +140,11 @@ public class CountSqlAggregator implements SqlAggregator
AggregatorFactory theCount = createCountAggregatorFactory( AggregatorFactory theCount = createCountAggregatorFactory(
name, name,
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder, inputAccessor.getRexBuilder(),
aggregateCall, aggregateCall,
project inputAccessor
); );
return Aggregation.create(theCount); return Aggregation.create(theCount);

View File

@ -20,9 +20,7 @@
package org.apache.druid.sql.calcite.aggregation.builtin; package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -53,19 +51,18 @@ import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class EarliestLatestAnySqlAggregator implements SqlAggregator public class EarliestLatestAnySqlAggregator implements SqlAggregator
{ {
@ -180,23 +177,17 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final List<RexNode> rexNodes = aggregateCall final List<RexNode> rexNodes = inputAccessor.getFields(aggregateCall.getArgList());
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes); final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes);
if (args == null) { if (args == null) {
return null; return null;
@ -216,7 +207,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
final String fieldName = getColumnName(plannerContext, virtualColumnRegistry, args.get(0), rexNodes.get(0)); final String fieldName = getColumnName(plannerContext, virtualColumnRegistry, args.get(0), rexNodes.get(0));
if (!rowSignature.contains(ColumnHolder.TIME_COLUMN_NAME) && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) { if (!inputAccessor.getInputRowSignature().contains(ColumnHolder.TIME_COLUMN_NAME)
&& (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) {
// This code is being run as part of the exploratory volcano planner, currently, the definition of these // This code is being run as part of the exploratory volcano planner, currently, the definition of these
// aggregators does not tell Calcite that they depend on a __time column being in existence, instead we are // aggregators does not tell Calcite that they depend on a __time column being in existence, instead we are
// allowing the volcano planner to explore paths that put projections which eliminate the time column in between // allowing the volcano planner to explore paths that put projections which eliminate the time column in between

View File

@ -20,8 +20,6 @@
package org.apache.druid.sql.calcite.aggregation.builtin; package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -38,19 +36,18 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class EarliestLatestBySqlAggregator implements SqlAggregator public class EarliestLatestBySqlAggregator implements SqlAggregator
{ {
@ -76,23 +73,17 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
final List<RexNode> rexNodes = aggregateCall final List<RexNode> rexNodes = inputAccessor.getFields(aggregateCall.getArgList());
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i))
.collect(Collectors.toList());
final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes); final List<DruidExpression> args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes);
if (args == null) { if (args == null) {
return null; return null;

View File

@ -22,7 +22,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@ -34,6 +33,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -53,24 +53,22 @@ public class GroupingSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, final InputAccessor inputAccessor,
List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
boolean finalizeAggregations final boolean finalizeAggregations
) )
{ {
List<String> arguments = aggregateCall.getArgList() List<String> arguments = aggregateCall.getArgList()
.stream() .stream()
.map(i -> getColumnName( .map(i -> getColumnName(
plannerContext, plannerContext,
rowSignature, inputAccessor.getInputRowSignature(),
project, inputAccessor.getProject(),
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder.getTypeFactory(), inputAccessor.getRexBuilder().getTypeFactory(),
i i
)) ))
.filter(Objects::nonNull) .filter(Objects::nonNull)

View File

@ -21,18 +21,16 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlInternalOperators; import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator; import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -59,12 +57,10 @@ public class LiteralSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
@ -73,7 +69,7 @@ public class LiteralSqlAggregator implements SqlAggregator
return null; return null;
} }
final RexNode literal = aggregateCall.rexList.get(0); final RexNode literal = aggregateCall.rexList.get(0);
final DruidExpression expr = Expressions.toDruidExpression(plannerContext, rowSignature, literal); final DruidExpression expr = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), literal);
if (expr == null) { if (expr == null) {
return null; return null;

View File

@ -21,18 +21,16 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.druid.error.DruidException; import org.apache.druid.error.DruidException;
import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.error.InvalidSqlInput;
import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -57,12 +55,10 @@ public abstract class SimpleSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext, final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry, final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final boolean finalizeAggregations final boolean finalizeAggregations
) )
@ -72,11 +68,9 @@ public abstract class SimpleSqlAggregator implements SqlAggregator
} }
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator( final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
rexBuilder,
plannerContext, plannerContext,
rowSignature,
aggregateCall, aggregateCall,
project inputAccessor
); );
if (arguments == null) { if (arguments == null) {

View File

@ -21,9 +21,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
@ -47,13 +45,13 @@ import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures; import org.apache.druid.sql.calcite.table.RowSignatures;
@ -89,12 +87,10 @@ public class StringSqlAggregator implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )
@ -102,20 +98,15 @@ public class StringSqlAggregator implements SqlAggregator
final List<DruidExpression> arguments = aggregateCall final List<DruidExpression> arguments = aggregateCall
.getArgList() .getArgList()
.stream() .stream()
.map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) .map(i -> inputAccessor.getField(i))
.map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) .map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (arguments.stream().anyMatch(Objects::isNull)) { if (arguments.stream().anyMatch(Objects::isNull)) {
return null; return null;
} }
RexNode separatorNode = Expressions.fromFieldAccess( RexNode separatorNode = inputAccessor.getField(aggregateCall.getArgList().get(1));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
if (!separatorNode.isA(SqlKind.LITERAL)) { if (!separatorNode.isA(SqlKind.LITERAL)) {
// separator must be a literal // separator must be a literal
return null; return null;
@ -133,12 +124,7 @@ public class StringSqlAggregator implements SqlAggregator
Integer maxSizeBytes = null; Integer maxSizeBytes = null;
if (arguments.size() > 2) { if (arguments.size() > 2) {
RexNode maxBytes = Expressions.fromFieldAccess( RexNode maxBytes = inputAccessor.getField(aggregateCall.getArgList().get(2));
rexBuilder.getTypeFactory(),
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
if (!maxBytes.isA(SqlKind.LITERAL)) { if (!maxBytes.isA(SqlKind.LITERAL)) {
// maxBytes must be a literal // maxBytes must be a literal
return null; return null;

View File

@ -20,14 +20,12 @@
package org.apache.druid.sql.calcite.expression; package org.apache.druid.sql.calcite.expression;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.UOE;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -55,12 +53,10 @@ public class WindowSqlAggregate implements SqlAggregator
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
PlannerContext plannerContext, PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry, VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name, String name,
AggregateCall aggregateCall, AggregateCall aggregateCall,
Project project, InputAccessor inputAccessor,
List<Aggregation> existingAggregations, List<Aggregation> existingAggregations,
boolean finalizeAggregations boolean finalizeAggregations
) )

View File

@ -580,7 +580,11 @@ public class DruidQuery
rowSignature, rowSignature,
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder, rexBuilder,
InputAccessor.buildFor(
rexBuilder,
rowSignature,
partialQuery.getSelectProject(), partialQuery.getSelectProject(),
null),
aggregations, aggregations,
aggName, aggName,
aggCall, aggCall,

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rel;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.Expressions;
import javax.annotation.Nullable;
import java.util.List;
import java.util.stream.Collectors;
/**
* Enables simpler access to input expressions.
*
* In case of aggregates it provides the constants transparently for aggregates.
*/
public class InputAccessor
{
private final Project project;
private final ImmutableList<RexLiteral> constants;
private final RexBuilder rexBuilder;
private final RowSignature inputRowSignature;
private final int inputFieldCount;
public static InputAccessor buildFor(
RexBuilder rexBuilder,
RowSignature inputRowSignature,
@Nullable Project project,
@Nullable ImmutableList<RexLiteral> constants)
{
return new InputAccessor(rexBuilder, inputRowSignature, project, constants);
}
private InputAccessor(
RexBuilder rexBuilder,
RowSignature inputRowSignature,
Project project,
ImmutableList<RexLiteral> constants)
{
this.rexBuilder = rexBuilder;
this.inputRowSignature = inputRowSignature;
this.project = project;
this.constants = constants;
this.inputFieldCount = project != null ? project.getRowType().getFieldCount() : inputRowSignature.size();
}
public RexNode getField(int argIndex)
{
if (argIndex < inputFieldCount) {
return Expressions.fromFieldAccess(
rexBuilder.getTypeFactory(),
inputRowSignature,
project,
argIndex);
} else {
return constants.get(argIndex - inputFieldCount);
}
}
public List<RexNode> getFields(List<Integer> argList)
{
return argList
.stream()
.map(i -> getField(i))
.collect(Collectors.toList());
}
public @Nullable Project getProject()
{
return project;
}
public RexBuilder getRexBuilder()
{
return rexBuilder;
}
public RowSignature getInputRowSignature()
{
return inputRowSignature;
}
}

View File

@ -177,7 +177,11 @@ public class Windowing
sourceRowSignature, sourceRowSignature,
null, null,
rexBuilder, rexBuilder,
InputAccessor.buildFor(
rexBuilder,
sourceRowSignature,
partialQuery.getSelectProject(), partialQuery.getSelectProject(),
window.constants),
Collections.emptyList(), Collections.emptyList(),
aggName, aggName,
aggregateCall, aggregateCall,

View File

@ -20,7 +20,6 @@
package org.apache.druid.sql.calcite.rule; package org.apache.druid.sql.calcite.rule;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
@ -32,6 +31,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -58,7 +58,7 @@ public class GroupByRules
final RowSignature rowSignature, final RowSignature rowSignature,
@Nullable final VirtualColumnRegistry virtualColumnRegistry, @Nullable final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder, final RexBuilder rexBuilder,
final Project project, final InputAccessor inputAccessor,
final List<Aggregation> existingAggregations, final List<Aggregation> existingAggregations,
final String name, final String name,
final AggregateCall call, final AggregateCall call,
@ -74,11 +74,7 @@ public class GroupByRules
if (call.filterArg >= 0) { if (call.filterArg >= 0) {
// AGG(xxx) FILTER(WHERE yyy) // AGG(xxx) FILTER(WHERE yyy)
final RexNode expression = Expressions.fromFieldAccess( final RexNode expression = inputAccessor.getField(call.filterArg);
rexBuilder.getTypeFactory(),
rowSignature,
project,
call.filterArg);
final DimFilter nonOptimizedFilter = Expressions.toFilter( final DimFilter nonOptimizedFilter = Expressions.toFilter(
plannerContext, plannerContext,
@ -136,12 +132,10 @@ public class GroupByRules
final Aggregation retVal = sqlAggregator.toDruidAggregation( final Aggregation retVal = sqlAggregator.toDruidAggregation(
plannerContext, plannerContext,
rowSignature,
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder,
name, name,
call, call,
project, inputAccessor,
existingAggregationsWithSameFilter, existingAggregationsWithSameFilter,
finalizeAggregations finalizeAggregations
); );

View File

@ -0,0 +1,26 @@
type: "operatorValidation"
sql: |
SELECT
dim1,
count(333) OVER () cc
FROM foo
WHERE length(dim1)>0
expectedOperators:
- type: naivePartition
partitionColumns: []
- type: "window"
processor:
type: "framedAgg"
frame: { peerType: "ROWS", lowUnbounded: true, lowOffset: 0, uppUnbounded: true, uppOffset: 0 }
aggregations:
- { type: "count", name: "w0" }
expectedResults:
- ["10.1",5]
- ["2",5]
- ["1",5]
- ["def",5]
- ["abc",5]