fix count and average SQL aggregators on constant virtual columns (#11208)

* fix count and average SQL aggregators on constant virtual columns

* style

* even better, why are we tracking virtual columns in aggregations at all if we have a virtual column registry

* oops missed a few

* remove unused

* this will fix it
This commit is contained in:
Clint Wylie 2021-05-10 13:41:48 -07:00 committed by GitHub
parent 15de29a2c4
commit f6662b4893
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 203 additions and 221 deletions

View File

@ -24,8 +24,8 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import java.nio.ByteBuffer;
@ -80,30 +80,24 @@ public class TDigestSketchUtils
}
public static boolean matchingAggregatorFactoryExists(
final VirtualColumnRegistry virtualColumnRegistry,
final DruidExpression input,
final Integer compression,
final Aggregation existing,
final TDigestSketchAggregatorFactory factory
)
{
// Check input for equivalence.
final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns()
.stream()
.filter(
virtualColumn ->
virtualColumn.getOutputName()
.equals(factory.getFieldName())
)
.findFirst()
.orElse(null);
final VirtualColumn virtualInput =
virtualColumnRegistry.findVirtualColumns(factory.requiredFields())
.stream()
.findFirst()
.orElse(null);
if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess()
&& input.getDirectColumn().equals(factory.getFieldName());
inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(factory.getFieldName());
} else {
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
.equals(input.getExpression());
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
}
return inputMatches && compression == factory.getCompression();
}

View File

@ -47,7 +47,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
@ -112,9 +111,9 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
if (factory instanceof TDigestSketchAggregatorFactory) {
final TDigestSketchAggregatorFactory theFactory = (TDigestSketchAggregatorFactory) factory;
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
virtualColumnRegistry,
input,
compression,
existing,
(TDigestSketchAggregatorFactory) factory
);
@ -129,8 +128,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) {
aggregatorFactory = new TDigestSketchAggregatorFactory(
aggName,
@ -143,7 +140,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
input,
ValueType.FLOAT
);
virtualColumns.add(virtualColumn);
aggregatorFactory = new TDigestSketchAggregatorFactory(
aggName,
virtualColumn.getOutputName(),
@ -151,10 +147,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
);
}
return Aggregation.create(
virtualColumns,
aggregatorFactory
);
return Aggregation.create(aggregatorFactory);
}
private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
@ -123,9 +122,9 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
for (AggregatorFactory factory : existing.getAggregatorFactories()) {
if (factory instanceof TDigestSketchAggregatorFactory) {
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
virtualColumnRegistry,
input,
compression,
existing,
(TDigestSketchAggregatorFactory) factory
);
@ -148,8 +147,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) {
aggregatorFactory = new TDigestSketchAggregatorFactory(
sketchName,
@ -162,7 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
input,
ValueType.FLOAT
);
virtualColumns.add(virtualColumn);
aggregatorFactory = new TDigestSketchAggregatorFactory(
sketchName,
virtualColumn.getOutputName(),
@ -171,7 +167,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
new TDigestSketchToQuantilePostAggregator(
name,

View File

@ -29,12 +29,10 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
import java.util.List;
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
@ -51,12 +49,10 @@ public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlA
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
name,

View File

@ -45,7 +45,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
@ -115,7 +114,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
tgtHllType = HllSketchAggregatorFactory.DEFAULT_TGT_HLL_TYPE.name();
}
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@ -150,7 +148,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
dataType
);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
}
aggregatorFactory = new HllSketchBuildAggregatorFactory(
@ -165,7 +162,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
return toAggregation(
name,
finalizeAggregations,
virtualColumns,
aggregatorFactory
);
}
@ -173,7 +169,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
protected abstract Aggregation toAggregation(
String name,
boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
);
}

View File

@ -28,12 +28,10 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
import java.util.List;
public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
@ -50,12 +48,10 @@ public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator imp
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory),
null
);

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@ -132,22 +131,16 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
// Check input for equivalence.
final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns()
.stream()
.filter(
virtualColumn ->
virtualColumn.getOutputName()
.equals(theFactory.getFieldName())
)
.findFirst()
.orElse(null);
final VirtualColumn virtualInput =
virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.stream()
.findFirst()
.orElse(null);
if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess()
&& input.getDirectColumn().equals(theFactory.getFieldName());
inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(theFactory.getFieldName());
} else {
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
.equals(input.getExpression());
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
}
final boolean matches = inputMatches
@ -172,8 +165,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) {
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
@ -186,7 +177,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
input,
ValueType.FLOAT
);
virtualColumns.add(virtualColumn);
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@ -195,7 +185,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
new DoublesSketchToQuantilePostAggregator(
name,

View File

@ -47,7 +47,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@ -110,8 +109,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) {
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
@ -124,7 +121,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
input,
ValueType.FLOAT
);
virtualColumns.add(virtualColumn);
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@ -133,7 +129,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
null
);

View File

@ -29,12 +29,10 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
import java.util.List;
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
@ -51,12 +49,10 @@ public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBase
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
name,

View File

@ -44,7 +44,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
@ -94,7 +93,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE;
}
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@ -130,7 +128,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
dataType
);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
}
aggregatorFactory = new SketchMergeAggregatorFactory(
@ -146,7 +143,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
return toAggregation(
name,
finalizeAggregations,
virtualColumns,
aggregatorFactory
);
}
@ -154,7 +150,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
protected abstract Aggregation toAggregation(
String name,
boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
);
}

View File

@ -28,12 +28,10 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
import java.util.List;
public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
@ -50,12 +48,10 @@ public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
virtualColumns,
Collections.singletonList(aggregatorFactory),
null
);

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class BloomFilterSqlAggregator implements SqlAggregator
@ -115,15 +114,10 @@ public class BloomFilterSqlAggregator implements SqlAggregator
// Check input for equivalence.
final boolean inputMatches;
final VirtualColumn virtualInput =
existing.getVirtualColumns()
.stream()
.filter(virtualColumn ->
virtualColumn.getOutputName().equals(theFactory.getField().getOutputName())
)
.findFirst()
.orElse(null);
final VirtualColumn virtualInput = virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.stream()
.findFirst()
.orElse(null);
if (virtualInput == null) {
if (input.isDirectColumnAccess()) {
inputMatches =
@ -150,7 +144,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
ValueType valueType = Calcites.getValueTypeForRelDataType(inputOperand.getType());
final DimensionSpec spec;
@ -173,7 +166,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
input,
inputOperand.getType()
);
virtualColumns.add(virtualColumn);
spec = new DefaultDimensionSpec(
virtualColumn.getOutputName(),
StringUtils.format("%s:%s", name, virtualColumn.getOutputName())
@ -186,10 +178,7 @@ public class BloomFilterSqlAggregator implements SqlAggregator
maxNumEntries
);
return Aggregation.create(
virtualColumns,
aggregatorFactory
);
return Aggregation.create(aggregatorFactory);
}
private static class BloomFilterSqlAggFunction extends SqlAggFunction

View File

@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
@ -188,15 +187,11 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
// Check input for equivalence.
final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns()
.stream()
.filter(
virtualColumn ->
virtualColumn.getOutputName()
.equals(theFactory.getFieldName())
)
.findFirst()
.orElse(null);
final VirtualColumn virtualInput =
virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.stream()
.findFirst()
.orElse(null);
if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess()
@ -224,8 +219,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) {
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
histogramName,
@ -242,7 +235,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
input,
ValueType.FLOAT
);
virtualColumns.add(virtualColumn);
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@ -255,7 +247,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability)
);

View File

@ -51,7 +51,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public class QuantileSqlAggregator implements SqlAggregator
@ -137,15 +136,11 @@ public class QuantileSqlAggregator implements SqlAggregator
// Check input for equivalence.
final boolean inputMatches;
final VirtualColumn virtualInput = existing.getVirtualColumns()
.stream()
.filter(
virtualColumn ->
virtualColumn.getOutputName()
.equals(theFactory.getFieldName())
)
.findFirst()
.orElse(null);
final VirtualColumn virtualInput =
virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
.stream()
.findFirst()
.orElse(null);
if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess()
@ -173,8 +168,6 @@ public class QuantileSqlAggregator implements SqlAggregator
}
// No existing match found. Create a new one.
final List<VirtualColumn> virtualColumns = new ArrayList<>();
if (input.isDirectColumnAccess()) {
if (rowSignature.getColumnType(input.getDirectColumn()).orElse(null) == ValueType.COMPLEX) {
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
@ -200,7 +193,6 @@ public class QuantileSqlAggregator implements SqlAggregator
} else {
final VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, ValueType.FLOAT);
virtualColumns.add(virtualColumn);
aggregatorFactory = new ApproximateHistogramAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@ -213,7 +205,6 @@ public class QuantileSqlAggregator implements SqlAggregator
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability)
);

View File

@ -48,7 +48,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@ -84,7 +83,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
final AggregatorFactory aggregatorFactory;
final RelDataType dataType = inputOperand.getType();
final ValueType inputType = Calcites.getValueTypeForRelDataType(dataType);
final List<VirtualColumn> virtualColumns = new ArrayList<>();
final DimensionSpec dimensionSpec;
final String aggName = StringUtils.format("%s:agg", name);
final SqlAggFunction func = calciteFunction();
@ -98,7 +96,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, dataType);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
virtualColumns.add(virtualColumn);
}
switch (inputType) {
@ -135,7 +132,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
}
return Aggregation.create(
virtualColumns,
ImmutableList.of(aggregatorFactory),
postAggregator
);

View File

@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.query.PerSegmentQueryOptimizationContext;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.Filter;
@ -166,7 +168,10 @@ public class FilteredAggregatorFactory extends AggregatorFactory
@Override
public List<String> requiredFields()
{
return delegate.requiredFields();
return ImmutableList.copyOf(
// use a set to get rid of dupes
ImmutableSet.<String>builder().addAll(delegate.requiredFields()).addAll(filter.getRequiredColumns()).build()
);
}
@Override

View File

@ -19,11 +19,14 @@
package org.apache.druid.query.aggregation;
import com.google.common.collect.ImmutableList;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.filter.TrueDimFilter;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
public class FilteredAggregatorFactoryTest
public class FilteredAggregatorFactoryTest extends InitializedNullHandlingTest
{
@Test
public void testSimpleNaming()
@ -44,4 +47,16 @@ public class FilteredAggregatorFactoryTest
null
).getName());
}
@Test
public void testRequiredFields()
{
Assert.assertEquals(
ImmutableList.of("x", "y"),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("x", "x"),
new SelectorDimFilter("y", "wat", null)
).requiredFields()
);
}
}

View File

@ -29,7 +29,6 @@ import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@ -44,17 +43,14 @@ import java.util.Set;
public class Aggregation
{
private final List<VirtualColumn> virtualColumns;
private final List<AggregatorFactory> aggregatorFactories;
private final PostAggregator postAggregator;
private Aggregation(
final List<VirtualColumn> virtualColumns,
final List<AggregatorFactory> aggregatorFactories,
final PostAggregator postAggregator
)
{
this.virtualColumns = Preconditions.checkNotNull(virtualColumns, "virtualColumns");
this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories");
this.postAggregator = postAggregator;
@ -88,19 +84,10 @@ public class Aggregation
}
}
public static Aggregation create(final List<VirtualColumn> virtualColumns, final AggregatorFactory aggregatorFactory)
{
return new Aggregation(
virtualColumns,
ImmutableList.of(aggregatorFactory),
null
);
}
public static Aggregation create(final AggregatorFactory aggregatorFactory)
{
return new Aggregation(
ImmutableList.of(),
ImmutableList.of(aggregatorFactory),
null
);
@ -108,7 +95,7 @@ public class Aggregation
public static Aggregation create(final PostAggregator postAggregator)
{
return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
return new Aggregation(Collections.emptyList(), postAggregator);
}
public static Aggregation create(
@ -116,21 +103,19 @@ public class Aggregation
final PostAggregator postAggregator
)
{
return new Aggregation(ImmutableList.of(), aggregatorFactories, postAggregator);
return new Aggregation(aggregatorFactories, postAggregator);
}
public static Aggregation create(
final List<VirtualColumn> virtualColumns,
final List<AggregatorFactory> aggregatorFactories,
final PostAggregator postAggregator
)
public List<String> getRequiredColumns()
{
return new Aggregation(virtualColumns, aggregatorFactories, postAggregator);
}
public List<VirtualColumn> getVirtualColumns()
{
return virtualColumns;
Set<String> columns = new HashSet<>();
for (AggregatorFactory agg : aggregatorFactories) {
columns.addAll(agg.requiredFields());
}
if (postAggregator != null) {
columns.addAll(postAggregator.getDependentFields());
}
return ImmutableList.copyOf(columns);
}
public List<AggregatorFactory> getAggregatorFactories()
@ -181,21 +166,10 @@ public class Aggregation
.optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature())
.getDimFilter();
Set<VirtualColumn> aggVirtualColumnsPlusFilterColumns = new HashSet<>(virtualColumns);
for (String column : baseOptimizedFilter.getRequiredColumns()) {
if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
}
}
final List<AggregatorFactory> newAggregators = new ArrayList<>();
for (AggregatorFactory agg : aggregatorFactories) {
if (agg instanceof FilteredAggregatorFactory) {
final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg;
for (String column : filteredAgg.getFilter().getRequiredColumns()) {
if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
}
}
newAggregators.add(
new FilteredAggregatorFactory(
filteredAgg.getAggregator(),
@ -209,7 +183,7 @@ public class Aggregation
}
}
return new Aggregation(new ArrayList<>(aggVirtualColumnsPlusFilterColumns), newAggregators, postAggregator);
return new Aggregation(newAggregators, postAggregator);
}
@Override
@ -222,23 +196,21 @@ public class Aggregation
return false;
}
final Aggregation that = (Aggregation) o;
return Objects.equals(virtualColumns, that.virtualColumns) &&
Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
return Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
Objects.equals(postAggregator, that.postAggregator);
}
@Override
public int hashCode()
{
return Objects.hash(virtualColumns, aggregatorFactories, postAggregator);
return Objects.hash(aggregatorFactories, postAggregator);
}
@Override
public String toString()
{
return "Aggregation{" +
"virtualColumns=" + virtualColumns +
", aggregatorFactories=" + aggregatorFactories +
"aggregatorFactories=" + aggregatorFactories +
", postAggregator=" + postAggregator +
'}';
}

View File

@ -52,7 +52,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ -94,7 +93,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
return null;
}
final List<VirtualColumn> myvirtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@ -120,7 +118,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, dataType);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
myvirtualColumns.add(virtualColumn);
}
aggregatorFactory = new CardinalityAggregatorFactory(
@ -133,7 +130,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
}
return Aggregation.create(
myvirtualColumns,
Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName()) : null
);

View File

@ -53,7 +53,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@ -134,19 +133,15 @@ public class ArraySqlAggregator implements SqlAggregator
break;
}
}
List<VirtualColumn> virtualColumns = new ArrayList<>();
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
} else {
VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType);
virtualColumns.add(vc);
fieldName = vc.getOutputName();
}
if (aggregateCall.isDistinct()) {
return Aggregation.create(
virtualColumns,
new ExpressionLambdaAggregatorFactory(
name,
ImmutableSet.of(fieldName),
@ -163,7 +158,6 @@ public class ArraySqlAggregator implements SqlAggregator
);
} else {
return Aggregation.create(
virtualColumns,
new ExpressionLambdaAggregatorFactory(
name,
ImmutableSet.of(fieldName),

View File

@ -31,6 +31,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
@ -78,37 +79,7 @@ public class AvgSqlAggregator implements SqlAggregator
return null;
}
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
fieldName = null;
expression = arg.getExpression();
}
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
sumType = ValueType.LONG;
} else {
sumType = ValueType.DOUBLE;
}
final String sumName = Calcites.makePrefixedName(name, "sum");
final String countName = Calcites.makePrefixedName(name, "count");
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
sumType,
sumName,
fieldName,
expression,
macroTable
);
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
countName,
plannerContext,
@ -119,6 +90,38 @@ public class AvgSqlAggregator implements SqlAggregator
project
);
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
sumType = ValueType.LONG;
} else {
sumType = ValueType.DOUBLE;
}
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
// if the filter or anywhere else defined a virtual column for us, re-use it
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
fieldName = vc != null ? vc.getOutputName() : null;
expression = vc != null ? null : arg.getExpression();
}
final String sumName = Calcites.makePrefixedName(name, "sum");
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
sumType,
sumName,
fieldName,
expression,
macroTable
);
return Aggregation.create(
ImmutableList.of(sum, count),
new ArithmeticPostAggregator(

View File

@ -134,7 +134,7 @@ public class CountSqlAggregator implements SqlAggregator
} else {
// Not COUNT(*), not distinct
// COUNT(x) should count all non-null values of x.
return Aggregation.create(createCountAggregatorFactory(
AggregatorFactory theCount = createCountAggregatorFactory(
name,
plannerContext,
rowSignature,
@ -142,7 +142,9 @@ public class CountSqlAggregator implements SqlAggregator
rexBuilder,
aggregateCall,
project
));
);
return Aggregation.create(theCount);
}
}
}

View File

@ -64,9 +64,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class EarliestLatestAnySqlAggregator implements SqlAggregator
{
@ -209,9 +207,6 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
}
return Aggregation.create(
Stream.of(virtualColumnRegistry.getVirtualColumn(fieldName))
.filter(Objects::nonNull)
.collect(Collectors.toList()),
Collections.singletonList(
aggregatorType.createAggregatorFactory(
aggregatorName,

View File

@ -636,7 +636,7 @@ public class DruidQuery
}
for (Aggregation aggregation : grouping.getAggregations()) {
virtualColumns.addAll(aggregation.getVirtualColumns());
virtualColumns.addAll(virtualColumnRegistry.findVirtualColumns(aggregation.getRequiredColumns()));
}
}

View File

@ -29,7 +29,9 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Provides facilities to create and re-use {@link VirtualColumn} definitions for dimensions, filters, and filtered
@ -128,6 +130,12 @@ public class VirtualColumnRegistry
return virtualColumnsByName.get(virtualColumnName);
}
@Nullable
public VirtualColumn getVirtualColumnByExpression(String expression)
{
return virtualColumnsByExpression.get(expression);
}
/**
* Get a signature representing the base signature plus all registered virtual columns.
*/
@ -145,4 +153,15 @@ public class VirtualColumnRegistry
return builder.build();
}
/**
* Given a list of column names, find any corresponding {@link VirtualColumn} with the same name
*/
public List<VirtualColumn> findVirtualColumns(List<String> allColumns)
{
return allColumns.stream()
.filter(this::isVirtualColumnDefined)
.map(this::getVirtualColumn)
.collect(Collectors.toList());
}
}

View File

@ -117,7 +117,6 @@ public class GroupByRules
if (doesMatch) {
existingAggregationsWithSameFilter.add(
Aggregation.create(
existingAggregation.getVirtualColumns(),
existingAggregation.getAggregatorFactories().stream()
.map(factory -> ((FilteredAggregatorFactory) factory).getAggregator())
.collect(Collectors.toList()),

View File

@ -50,6 +50,7 @@ import org.apache.druid.query.QueryException;
import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
@ -112,6 +113,8 @@ import org.apache.druid.query.topn.DimensionTopNMetricSpec;
import org.apache.druid.query.topn.InvertedTopNMetricSpec;
import org.apache.druid.query.topn.NumericTopNMetricSpec;
import org.apache.druid.query.topn.TopNQueryBuilder;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.join.JoinType;
@ -18926,4 +18929,76 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
expectedResults
);
}
@Test
public void testCountAndAverageByConstantVirtualColumn() throws Exception
{
List<VirtualColumn> virtualColumns;
List<AggregatorFactory> aggs;
if (useDefault) {
aggs = ImmutableList.of(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("v0", null, null))
),
new LongSumAggregatorFactory("a1:sum", null, "325323", TestExprMacroTable.INSTANCE),
new CountAggregatorFactory("a1:count")
);
virtualColumns = ImmutableList.of(
expressionVirtualColumn("v0", "'10.1'", ValueType.STRING)
);
} else {
aggs = ImmutableList.of(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("v0", null, null))
),
new LongSumAggregatorFactory("a1:sum", "v1"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a1:count"),
not(selector("v1", null, null))
)
);
virtualColumns = ImmutableList.of(
expressionVirtualColumn("v0", "'10.1'", ValueType.STRING),
expressionVirtualColumn("v1", "325323", ValueType.LONG)
);
}
testQuery(
"SELECT dim5, COUNT(dim1), AVG(l1) FROM druid.numfoo WHERE dim1 = '10.1' AND l1 = 325323 GROUP BY dim5",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE3)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimFilter(
and(
selector("dim1", "10.1", null),
selector("l1", "325323", null)
)
)
.setGranularity(Granularities.ALL)
.setVirtualColumns(VirtualColumns.create(virtualColumns))
.setDimensions(new DefaultDimensionSpec("dim5", "_d0", ValueType.STRING))
.setAggregatorSpecs(aggs)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
"a1",
"quotient",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a1:sum"),
new FieldAccessPostAggregator(null, "a1:count")
)
)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{"ab", 1L, 325323L}
)
);
}
}