fix count and average SQL aggregators on constant virtual columns (#11208)

* fix count and average SQL aggregators on constant virtual columns

* style

* even better, why are we tracking virtual columns in aggregations at all if we have a virtual column registry

* oops missed a few

* remove unused

* this will fix it
This commit is contained in:
Clint Wylie 2021-05-10 13:41:48 -07:00 committed by GitHub
parent 15de29a2c4
commit f6662b4893
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 203 additions and 221 deletions

View File

@ -24,8 +24,8 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -80,30 +80,24 @@ public class TDigestSketchUtils
} }
public static boolean matchingAggregatorFactoryExists( public static boolean matchingAggregatorFactoryExists(
final VirtualColumnRegistry virtualColumnRegistry,
final DruidExpression input, final DruidExpression input,
final Integer compression, final Integer compression,
final Aggregation existing,
final TDigestSketchAggregatorFactory factory final TDigestSketchAggregatorFactory factory
) )
{ {
// Check input for equivalence. // Check input for equivalence.
final boolean inputMatches; final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns() final VirtualColumn virtualInput =
.stream() virtualColumnRegistry.findVirtualColumns(factory.requiredFields())
.filter( .stream()
virtualColumn -> .findFirst()
virtualColumn.getOutputName() .orElse(null);
.equals(factory.getFieldName())
)
.findFirst()
.orElse(null);
if (virtualInput == null) { if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess() inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(factory.getFieldName());
&& input.getDirectColumn().equals(factory.getFieldName());
} else { } else {
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression() inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
.equals(input.getExpression());
} }
return inputMatches && compression == factory.getCompression(); return inputMatches && compression == factory.getCompression();
} }

View File

@ -47,7 +47,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class TDigestGenerateSketchSqlAggregator implements SqlAggregator public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
@ -112,9 +111,9 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
if (factory instanceof TDigestSketchAggregatorFactory) { if (factory instanceof TDigestSketchAggregatorFactory) {
final TDigestSketchAggregatorFactory theFactory = (TDigestSketchAggregatorFactory) factory; final TDigestSketchAggregatorFactory theFactory = (TDigestSketchAggregatorFactory) factory;
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists( final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
virtualColumnRegistry,
input, input,
compression, compression,
existing,
(TDigestSketchAggregatorFactory) factory (TDigestSketchAggregatorFactory) factory
); );
@ -129,8 +128,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
aggregatorFactory = new TDigestSketchAggregatorFactory( aggregatorFactory = new TDigestSketchAggregatorFactory(
aggName, aggName,
@ -143,7 +140,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
input, input,
ValueType.FLOAT ValueType.FLOAT
); );
virtualColumns.add(virtualColumn);
aggregatorFactory = new TDigestSketchAggregatorFactory( aggregatorFactory = new TDigestSketchAggregatorFactory(
aggName, aggName,
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
@ -151,10 +147,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
); );
} }
return Aggregation.create( return Aggregation.create(aggregatorFactory);
virtualColumns,
aggregatorFactory
);
} }
private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class TDigestSketchQuantileSqlAggregator implements SqlAggregator public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
@ -123,9 +122,9 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
for (AggregatorFactory factory : existing.getAggregatorFactories()) { for (AggregatorFactory factory : existing.getAggregatorFactories()) {
if (factory instanceof TDigestSketchAggregatorFactory) { if (factory instanceof TDigestSketchAggregatorFactory) {
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists( final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
virtualColumnRegistry,
input, input,
compression, compression,
existing,
(TDigestSketchAggregatorFactory) factory (TDigestSketchAggregatorFactory) factory
); );
@ -148,8 +147,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
aggregatorFactory = new TDigestSketchAggregatorFactory( aggregatorFactory = new TDigestSketchAggregatorFactory(
sketchName, sketchName,
@ -162,7 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
input, input,
ValueType.FLOAT ValueType.FLOAT
); );
virtualColumns.add(virtualColumn);
aggregatorFactory = new TDigestSketchAggregatorFactory( aggregatorFactory = new TDigestSketchAggregatorFactory(
sketchName, sketchName,
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
@ -171,7 +167,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
new TDigestSketchToQuantilePostAggregator( new TDigestSketchToQuantilePostAggregator(
name, name,

View File

@ -29,12 +29,10 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections; import java.util.Collections;
import java.util.List;
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{ {
@ -51,12 +49,10 @@ public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlA
protected Aggregation toAggregation( protected Aggregation toAggregation(
String name, String name,
boolean finalizeAggregations, boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory AggregatorFactory aggregatorFactory
) )
{ {
return Aggregation.create( return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory), Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator( finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
name, name,

View File

@ -45,7 +45,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public abstract class HllSketchBaseSqlAggregator implements SqlAggregator public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
@ -115,7 +114,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
tgtHllType = HllSketchAggregatorFactory.DEFAULT_TGT_HLL_TYPE.name(); tgtHllType = HllSketchAggregatorFactory.DEFAULT_TGT_HLL_TYPE.name();
} }
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@ -150,7 +148,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
dataType dataType
); );
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType); dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
} }
aggregatorFactory = new HllSketchBuildAggregatorFactory( aggregatorFactory = new HllSketchBuildAggregatorFactory(
@ -165,7 +162,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
return toAggregation( return toAggregation(
name, name,
finalizeAggregations, finalizeAggregations,
virtualColumns,
aggregatorFactory aggregatorFactory
); );
} }
@ -173,7 +169,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
protected abstract Aggregation toAggregation( protected abstract Aggregation toAggregation(
String name, String name,
boolean finalizeAggregations, boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory AggregatorFactory aggregatorFactory
); );
} }

View File

@ -28,12 +28,10 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections; import java.util.Collections;
import java.util.List;
public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{ {
@ -50,12 +48,10 @@ public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator imp
protected Aggregation toAggregation( protected Aggregation toAggregation(
String name, String name,
boolean finalizeAggregations, boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory AggregatorFactory aggregatorFactory
) )
{ {
return Aggregation.create( return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory), Collections.singletonList(aggregatorFactory),
null null
); );

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@ -132,22 +131,16 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
// Check input for equivalence. // Check input for equivalence.
final boolean inputMatches; final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns() final VirtualColumn virtualInput =
.stream() virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.filter( .stream()
virtualColumn -> .findFirst()
virtualColumn.getOutputName() .orElse(null);
.equals(theFactory.getFieldName())
)
.findFirst()
.orElse(null);
if (virtualInput == null) { if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess() inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(theFactory.getFieldName());
&& input.getDirectColumn().equals(theFactory.getFieldName());
} else { } else {
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression() inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
.equals(input.getExpression());
} }
final boolean matches = inputMatches final boolean matches = inputMatches
@ -172,8 +165,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
aggregatorFactory = new DoublesSketchAggregatorFactory( aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName, histogramName,
@ -186,7 +177,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
input, input,
ValueType.FLOAT ValueType.FLOAT
); );
virtualColumns.add(virtualColumn);
aggregatorFactory = new DoublesSketchAggregatorFactory( aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName, histogramName,
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
@ -195,7 +185,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
new DoublesSketchToQuantilePostAggregator( new DoublesSketchToQuantilePostAggregator(
name, name,

View File

@ -47,7 +47,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@ -110,8 +109,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
aggregatorFactory = new DoublesSketchAggregatorFactory( aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName, histogramName,
@ -124,7 +121,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
input, input,
ValueType.FLOAT ValueType.FLOAT
); );
virtualColumns.add(virtualColumn);
aggregatorFactory = new DoublesSketchAggregatorFactory( aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName, histogramName,
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
@ -133,7 +129,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
null null
); );

View File

@ -29,12 +29,10 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections; import java.util.Collections;
import java.util.List;
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{ {
@ -51,12 +49,10 @@ public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBase
protected Aggregation toAggregation( protected Aggregation toAggregation(
String name, String name,
boolean finalizeAggregations, boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory AggregatorFactory aggregatorFactory
) )
{ {
return Aggregation.create( return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory), Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator( finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
name, name,

View File

@ -44,7 +44,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
@ -94,7 +93,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE; sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE;
} }
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@ -130,7 +128,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
dataType dataType
); );
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType); dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
} }
aggregatorFactory = new SketchMergeAggregatorFactory( aggregatorFactory = new SketchMergeAggregatorFactory(
@ -146,7 +143,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
return toAggregation( return toAggregation(
name, name,
finalizeAggregations, finalizeAggregations,
virtualColumns,
aggregatorFactory aggregatorFactory
); );
} }
@ -154,7 +150,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
protected abstract Aggregation toAggregation( protected abstract Aggregation toAggregation(
String name, String name,
boolean finalizeAggregations, boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory AggregatorFactory aggregatorFactory
); );
} }

View File

@ -28,12 +28,10 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections; import java.util.Collections;
import java.util.List;
public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{ {
@ -50,12 +48,10 @@ public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator
protected Aggregation toAggregation( protected Aggregation toAggregation(
String name, String name,
boolean finalizeAggregations, boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory AggregatorFactory aggregatorFactory
) )
{ {
return Aggregation.create( return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory), Collections.singletonList(aggregatorFactory),
null null
); );

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class BloomFilterSqlAggregator implements SqlAggregator public class BloomFilterSqlAggregator implements SqlAggregator
@ -115,15 +114,10 @@ public class BloomFilterSqlAggregator implements SqlAggregator
// Check input for equivalence. // Check input for equivalence.
final boolean inputMatches; final boolean inputMatches;
final VirtualColumn virtualInput = final VirtualColumn virtualInput = virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
existing.getVirtualColumns() .stream()
.stream() .findFirst()
.filter(virtualColumn -> .orElse(null);
virtualColumn.getOutputName().equals(theFactory.getField().getOutputName())
)
.findFirst()
.orElse(null);
if (virtualInput == null) { if (virtualInput == null) {
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
inputMatches = inputMatches =
@ -150,7 +144,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
ValueType valueType = Calcites.getValueTypeForRelDataType(inputOperand.getType()); ValueType valueType = Calcites.getValueTypeForRelDataType(inputOperand.getType());
final DimensionSpec spec; final DimensionSpec spec;
@ -173,7 +166,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
input, input,
inputOperand.getType() inputOperand.getType()
); );
virtualColumns.add(virtualColumn);
spec = new DefaultDimensionSpec( spec = new DefaultDimensionSpec(
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
StringUtils.format("%s:%s", name, virtualColumn.getOutputName()) StringUtils.format("%s:%s", name, virtualColumn.getOutputName())
@ -186,10 +178,7 @@ public class BloomFilterSqlAggregator implements SqlAggregator
maxNumEntries maxNumEntries
); );
return Aggregation.create( return Aggregation.create(aggregatorFactory);
virtualColumns,
aggregatorFactory
);
} }
private static class BloomFilterSqlAggFunction extends SqlAggFunction private static class BloomFilterSqlAggFunction extends SqlAggFunction

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
@ -188,15 +187,11 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
// Check input for equivalence. // Check input for equivalence.
final boolean inputMatches; final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns() final VirtualColumn virtualInput =
.stream() virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.filter( .stream()
virtualColumn -> .findFirst()
virtualColumn.getOutputName() .orElse(null);
.equals(theFactory.getFieldName())
)
.findFirst()
.orElse(null);
if (virtualInput == null) { if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess() inputMatches = input.isDirectColumnAccess()
@ -224,8 +219,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory( aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
histogramName, histogramName,
@ -242,7 +235,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
input, input,
ValueType.FLOAT ValueType.FLOAT
); );
virtualColumns.add(virtualColumn);
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory( aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
histogramName, histogramName,
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
@ -255,7 +247,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability) new QuantilePostAggregator(name, histogramName, probability)
); );

View File

@ -51,7 +51,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class QuantileSqlAggregator implements SqlAggregator public class QuantileSqlAggregator implements SqlAggregator
@ -137,15 +136,11 @@ public class QuantileSqlAggregator implements SqlAggregator
// Check input for equivalence. // Check input for equivalence.
final boolean inputMatches; final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns() final VirtualColumn virtualInput =
.stream() virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.filter( .stream()
virtualColumn -> .findFirst()
virtualColumn.getOutputName() .orElse(null);
.equals(theFactory.getFieldName())
)
.findFirst()
.orElse(null);
if (virtualInput == null) { if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess() inputMatches = input.isDirectColumnAccess()
@ -173,8 +168,6 @@ public class QuantileSqlAggregator implements SqlAggregator
} }
// No existing match found. Create a new one. // No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) { if (input.isDirectColumnAccess()) {
if (rowSignature.getColumnType(input.getDirectColumn()).orElse(null) == ValueType.COMPLEX) { if (rowSignature.getColumnType(input.getDirectColumn()).orElse(null) == ValueType.COMPLEX) {
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory( aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
@ -200,7 +193,6 @@ public class QuantileSqlAggregator implements SqlAggregator
} else { } else {
final VirtualColumn virtualColumn = final VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, ValueType.FLOAT); virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, ValueType.FLOAT);
virtualColumns.add(virtualColumn);
aggregatorFactory = new ApproximateHistogramAggregatorFactory( aggregatorFactory = new ApproximateHistogramAggregatorFactory(
histogramName, histogramName,
virtualColumn.getOutputName(), virtualColumn.getOutputName(),
@ -213,7 +205,6 @@ public class QuantileSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability) new QuantilePostAggregator(name, histogramName, probability)
); );

View File

@ -48,7 +48,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public abstract class BaseVarianceSqlAggregator implements SqlAggregator public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@ -84,7 +83,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final RelDataType dataType = inputOperand.getType(); final RelDataType dataType = inputOperand.getType();
final ValueType inputType = Calcites.getValueTypeForRelDataType(dataType); final ValueType inputType = Calcites.getValueTypeForRelDataType(dataType);
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final DimensionSpec dimensionSpec; final DimensionSpec dimensionSpec;
final String aggName = StringUtils.format("%s:agg", name); final String aggName = StringUtils.format("%s:agg", name);
final SqlAggFunction func = calciteFunction(); final SqlAggFunction func = calciteFunction();
@ -98,7 +96,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
VirtualColumn virtualColumn = VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, dataType); virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, dataType);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType); dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
} }
switch (inputType) { switch (inputType) {
@ -135,7 +132,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
postAggregator postAggregator
); );

View File

@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.query.PerSegmentQueryOptimizationContext; import org.apache.druid.query.PerSegmentQueryOptimizationContext;
import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.Filter; import org.apache.druid.query.filter.Filter;
@ -166,7 +168,10 @@ public class FilteredAggregatorFactory extends AggregatorFactory
@Override @Override
public List<String> requiredFields() public List<String> requiredFields()
{ {
return delegate.requiredFields(); return ImmutableList.copyOf(
// use a set to get rid of dupes
ImmutableSet.<String>builder().addAll(delegate.requiredFields()).addAll(filter.getRequiredColumns()).build()
);
} }
@Override @Override

View File

@ -19,11 +19,14 @@
package org.apache.druid.query.aggregation; package org.apache.druid.query.aggregation;
import com.google.common.collect.ImmutableList;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
public class FilteredAggregatorFactoryTest public class FilteredAggregatorFactoryTest extends InitializedNullHandlingTest
{ {
@Test @Test
public void testSimpleNaming() public void testSimpleNaming()
@ -44,4 +47,16 @@ public class FilteredAggregatorFactoryTest
null null
).getName()); ).getName());
} }
@Test
public void testRequiredFields()
{
Assert.assertEquals(
ImmutableList.of("x", "y"),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("x", "x"),
new SelectorDimFilter("y", "wat", null)
).requiredFields()
);
}
} }

View File

@ -29,7 +29,6 @@ import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@ -44,17 +43,14 @@ import java.util.Set;
public class Aggregation public class Aggregation
{ {
private final List<VirtualColumn> virtualColumns;
private final List<AggregatorFactory> aggregatorFactories; private final List<AggregatorFactory> aggregatorFactories;
private final PostAggregator postAggregator; private final PostAggregator postAggregator;
private Aggregation( private Aggregation(
final List<VirtualColumn> virtualColumns,
final List<AggregatorFactory> aggregatorFactories, final List<AggregatorFactory> aggregatorFactories,
final PostAggregator postAggregator final PostAggregator postAggregator
) )
{ {
this.virtualColumns = Preconditions.checkNotNull(virtualColumns, "virtualColumns");
this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories"); this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories");
this.postAggregator = postAggregator; this.postAggregator = postAggregator;
@ -88,19 +84,10 @@ public class Aggregation
} }
} }
public static Aggregation create(final List<VirtualColumn> virtualColumns, final AggregatorFactory aggregatorFactory)
{
return new Aggregation(
virtualColumns,
ImmutableList.of(aggregatorFactory),
null
);
}
public static Aggregation create(final AggregatorFactory aggregatorFactory) public static Aggregation create(final AggregatorFactory aggregatorFactory)
{ {
return new Aggregation( return new Aggregation(
ImmutableList.of(),
ImmutableList.of(aggregatorFactory), ImmutableList.of(aggregatorFactory),
null null
); );
@ -108,7 +95,7 @@ public class Aggregation
public static Aggregation create(final PostAggregator postAggregator) public static Aggregation create(final PostAggregator postAggregator)
{ {
return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator); return new Aggregation(Collections.emptyList(), postAggregator);
} }
public static Aggregation create( public static Aggregation create(
@ -116,21 +103,19 @@ public class Aggregation
final PostAggregator postAggregator final PostAggregator postAggregator
) )
{ {
return new Aggregation(ImmutableList.of(), aggregatorFactories, postAggregator); return new Aggregation(aggregatorFactories, postAggregator);
} }
public static Aggregation create( public List<String> getRequiredColumns()
final List<VirtualColumn> virtualColumns,
final List<AggregatorFactory> aggregatorFactories,
final PostAggregator postAggregator
)
{ {
return new Aggregation(virtualColumns, aggregatorFactories, postAggregator); Set<String> columns = new HashSet<>();
} for (AggregatorFactory agg : aggregatorFactories) {
columns.addAll(agg.requiredFields());
public List<VirtualColumn> getVirtualColumns() }
{ if (postAggregator != null) {
return virtualColumns; columns.addAll(postAggregator.getDependentFields());
}
return ImmutableList.copyOf(columns);
} }
public List<AggregatorFactory> getAggregatorFactories() public List<AggregatorFactory> getAggregatorFactories()
@ -181,21 +166,10 @@ public class Aggregation
.optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature()) .optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature())
.getDimFilter(); .getDimFilter();
Set<VirtualColumn> aggVirtualColumnsPlusFilterColumns = new HashSet<>(virtualColumns);
for (String column : baseOptimizedFilter.getRequiredColumns()) {
if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
}
}
final List<AggregatorFactory> newAggregators = new ArrayList<>(); final List<AggregatorFactory> newAggregators = new ArrayList<>();
for (AggregatorFactory agg : aggregatorFactories) { for (AggregatorFactory agg : aggregatorFactories) {
if (agg instanceof FilteredAggregatorFactory) { if (agg instanceof FilteredAggregatorFactory) {
final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg; final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg;
for (String column : filteredAgg.getFilter().getRequiredColumns()) {
if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
}
}
newAggregators.add( newAggregators.add(
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
filteredAgg.getAggregator(), filteredAgg.getAggregator(),
@ -209,7 +183,7 @@ public class Aggregation
} }
} }
return new Aggregation(new ArrayList<>(aggVirtualColumnsPlusFilterColumns), newAggregators, postAggregator); return new Aggregation(newAggregators, postAggregator);
} }
@Override @Override
@ -222,23 +196,21 @@ public class Aggregation
return false; return false;
} }
final Aggregation that = (Aggregation) o; final Aggregation that = (Aggregation) o;
return Objects.equals(virtualColumns, that.virtualColumns) && return Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
Objects.equals(postAggregator, that.postAggregator); Objects.equals(postAggregator, that.postAggregator);
} }
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(virtualColumns, aggregatorFactories, postAggregator); return Objects.hash(aggregatorFactories, postAggregator);
} }
@Override @Override
public String toString() public String toString()
{ {
return "Aggregation{" + return "Aggregation{" +
"virtualColumns=" + virtualColumns + "aggregatorFactories=" + aggregatorFactories +
", aggregatorFactories=" + aggregatorFactories +
", postAggregator=" + postAggregator + ", postAggregator=" + postAggregator +
'}'; '}';
} }

View File

@ -52,7 +52,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -94,7 +93,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
return null; return null;
} }
final List<VirtualColumn> myvirtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory; final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@ -120,7 +118,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
VirtualColumn virtualColumn = VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, dataType); virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, dataType);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType); dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
myvirtualColumns.add(virtualColumn);
} }
aggregatorFactory = new CardinalityAggregatorFactory( aggregatorFactory = new CardinalityAggregatorFactory(
@ -133,7 +130,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
myvirtualColumns,
Collections.singletonList(aggregatorFactory), Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName()) : null finalizeAggregations ? new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName()) : null
); );

View File

@ -53,7 +53,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -134,19 +133,15 @@ public class ArraySqlAggregator implements SqlAggregator
break; break;
} }
} }
List<VirtualColumn> virtualColumns = new ArrayList<>();
if (arg.isDirectColumnAccess()) { if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn(); fieldName = arg.getDirectColumn();
} else { } else {
VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType); VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType);
virtualColumns.add(vc);
fieldName = vc.getOutputName(); fieldName = vc.getOutputName();
} }
if (aggregateCall.isDistinct()) { if (aggregateCall.isDistinct()) {
return Aggregation.create( return Aggregation.create(
virtualColumns,
new ExpressionLambdaAggregatorFactory( new ExpressionLambdaAggregatorFactory(
name, name,
ImmutableSet.of(fieldName), ImmutableSet.of(fieldName),
@ -163,7 +158,6 @@ public class ArraySqlAggregator implements SqlAggregator
); );
} else { } else {
return Aggregation.create( return Aggregation.create(
virtualColumns,
new ExpressionLambdaAggregatorFactory( new ExpressionLambdaAggregatorFactory(
name, name,
ImmutableSet.of(fieldName), ImmutableSet.of(fieldName),

View File

@ -31,6 +31,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
@ -78,37 +79,7 @@ public class AvgSqlAggregator implements SqlAggregator
return null; return null;
} }
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
fieldName = null;
expression = arg.getExpression();
}
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
sumType = ValueType.LONG;
} else {
sumType = ValueType.DOUBLE;
}
final String sumName = Calcites.makePrefixedName(name, "sum");
final String countName = Calcites.makePrefixedName(name, "count"); final String countName = Calcites.makePrefixedName(name, "count");
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
sumType,
sumName,
fieldName,
expression,
macroTable
);
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory( final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
countName, countName,
plannerContext, plannerContext,
@ -119,6 +90,38 @@ public class AvgSqlAggregator implements SqlAggregator
project project
); );
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
sumType = ValueType.LONG;
} else {
sumType = ValueType.DOUBLE;
}
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
// if the filter or anywhere else defined a virtual column for us, re-use it
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
fieldName = vc != null ? vc.getOutputName() : null;
expression = vc != null ? null : arg.getExpression();
}
final String sumName = Calcites.makePrefixedName(name, "sum");
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
sumType,
sumName,
fieldName,
expression,
macroTable
);
return Aggregation.create( return Aggregation.create(
ImmutableList.of(sum, count), ImmutableList.of(sum, count),
new ArithmeticPostAggregator( new ArithmeticPostAggregator(

View File

@ -134,7 +134,7 @@ public class CountSqlAggregator implements SqlAggregator
} else { } else {
// Not COUNT(*), not distinct // Not COUNT(*), not distinct
// COUNT(x) should count all non-null values of x. // COUNT(x) should count all non-null values of x.
return Aggregation.create(createCountAggregatorFactory( AggregatorFactory theCount = createCountAggregatorFactory(
name, name,
plannerContext, plannerContext,
rowSignature, rowSignature,
@ -142,7 +142,9 @@ public class CountSqlAggregator implements SqlAggregator
rexBuilder, rexBuilder,
aggregateCall, aggregateCall,
project project
)); );
return Aggregation.create(theCount);
} }
} }
} }

View File

@ -64,9 +64,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
public class EarliestLatestAnySqlAggregator implements SqlAggregator public class EarliestLatestAnySqlAggregator implements SqlAggregator
{ {
@ -209,9 +207,6 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
} }
return Aggregation.create( return Aggregation.create(
Stream.of(virtualColumnRegistry.getVirtualColumn(fieldName))
.filter(Objects::nonNull)
.collect(Collectors.toList()),
Collections.singletonList( Collections.singletonList(
aggregatorType.createAggregatorFactory( aggregatorType.createAggregatorFactory(
aggregatorName, aggregatorName,

View File

@ -636,7 +636,7 @@ public class DruidQuery
} }
for (Aggregation aggregation : grouping.getAggregations()) { for (Aggregation aggregation : grouping.getAggregations()) {
virtualColumns.addAll(aggregation.getVirtualColumns()); virtualColumns.addAll(virtualColumnRegistry.findVirtualColumns(aggregation.getRequiredColumns()));
} }
} }

View File

@ -29,7 +29,9 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
/** /**
* Provides facilities to create and re-use {@link VirtualColumn} definitions for dimensions, filters, and filtered * Provides facilities to create and re-use {@link VirtualColumn} definitions for dimensions, filters, and filtered
@ -128,6 +130,12 @@ public class VirtualColumnRegistry
return virtualColumnsByName.get(virtualColumnName); return virtualColumnsByName.get(virtualColumnName);
} }
@Nullable
public VirtualColumn getVirtualColumnByExpression(String expression)
{
return virtualColumnsByExpression.get(expression);
}
/** /**
* Get a signature representing the base signature plus all registered virtual columns. * Get a signature representing the base signature plus all registered virtual columns.
*/ */
@ -145,4 +153,15 @@ public class VirtualColumnRegistry
return builder.build(); return builder.build();
} }
/**
* Given a list of column names, find any corresponding {@link VirtualColumn} with the same name
*/
public List<VirtualColumn> findVirtualColumns(List<String> allColumns)
{
return allColumns.stream()
.filter(this::isVirtualColumnDefined)
.map(this::getVirtualColumn)
.collect(Collectors.toList());
}
} }

View File

@ -117,7 +117,6 @@ public class GroupByRules
if (doesMatch) { if (doesMatch) {
existingAggregationsWithSameFilter.add( existingAggregationsWithSameFilter.add(
Aggregation.create( Aggregation.create(
existingAggregation.getVirtualColumns(),
existingAggregation.getAggregatorFactories().stream() existingAggregation.getAggregatorFactories().stream()
.map(factory -> ((FilteredAggregatorFactory) factory).getAggregator()) .map(factory -> ((FilteredAggregatorFactory) factory).getAggregator())
.collect(Collectors.toList()), .collect(Collectors.toList()),

View File

@ -50,6 +50,7 @@ import org.apache.druid.query.QueryException;
import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TableDataSource; import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource; import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
@ -112,6 +113,8 @@ import org.apache.druid.query.topn.DimensionTopNMetricSpec;
import org.apache.druid.query.topn.InvertedTopNMetricSpec; import org.apache.druid.query.topn.InvertedTopNMetricSpec;
import org.apache.druid.query.topn.NumericTopNMetricSpec; import org.apache.druid.query.topn.NumericTopNMetricSpec;
import org.apache.druid.query.topn.TopNQueryBuilder; import org.apache.druid.query.topn.TopNQueryBuilder;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.join.JoinType; import org.apache.druid.segment.join.JoinType;
@ -18926,4 +18929,76 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
expectedResults expectedResults
); );
} }
@Test
public void testCountAndAverageByConstantVirtualColumn() throws Exception
{
List<VirtualColumn> virtualColumns;
List<AggregatorFactory> aggs;
if (useDefault) {
aggs = ImmutableList.of(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("v0", null, null))
),
new LongSumAggregatorFactory("a1:sum", null, "325323", TestExprMacroTable.INSTANCE),
new CountAggregatorFactory("a1:count")
);
virtualColumns = ImmutableList.of(
expressionVirtualColumn("v0", "'10.1'", ValueType.STRING)
);
} else {
aggs = ImmutableList.of(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("v0", null, null))
),
new LongSumAggregatorFactory("a1:sum", "v1"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a1:count"),
not(selector("v1", null, null))
)
);
virtualColumns = ImmutableList.of(
expressionVirtualColumn("v0", "'10.1'", ValueType.STRING),
expressionVirtualColumn("v1", "325323", ValueType.LONG)
);
}
testQuery(
"SELECT dim5, COUNT(dim1), AVG(l1) FROM druid.numfoo WHERE dim1 = '10.1' AND l1 = 325323 GROUP BY dim5",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE3)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimFilter(
and(
selector("dim1", "10.1", null),
selector("l1", "325323", null)
)
)
.setGranularity(Granularities.ALL)
.setVirtualColumns(VirtualColumns.create(virtualColumns))
.setDimensions(new DefaultDimensionSpec("dim5", "_d0", ValueType.STRING))
.setAggregatorSpecs(aggs)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
"a1",
"quotient",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a1:sum"),
new FieldAccessPostAggregator(null, "a1:count")
)
)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{"ab", 1L, 325323L}
)
);
}
} }