mirror of https://github.com/apache/druid.git
fix count and average SQL aggregators on constant virtual columns (#11208)
* fix count and average SQL aggregators on constant virtual columns * style * even better, why are we tracking virtual columns in aggregations at all if we have a virtual column registry * oops missed a few * remove unused * this will fix it
This commit is contained in:
parent
15de29a2c4
commit
f6662b4893
|
@ -24,8 +24,8 @@ import org.apache.druid.java.util.common.IAE;
|
|||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
|
@ -80,30 +80,24 @@ public class TDigestSketchUtils
|
|||
}
|
||||
|
||||
public static boolean matchingAggregatorFactoryExists(
|
||||
final VirtualColumnRegistry virtualColumnRegistry,
|
||||
final DruidExpression input,
|
||||
final Integer compression,
|
||||
final Aggregation existing,
|
||||
final TDigestSketchAggregatorFactory factory
|
||||
)
|
||||
{
|
||||
// Check input for equivalence.
|
||||
final boolean inputMatches;
|
||||
final VirtualColumn virtualInput = existing.getVirtualColumns()
|
||||
.stream()
|
||||
.filter(
|
||||
virtualColumn ->
|
||||
virtualColumn.getOutputName()
|
||||
.equals(factory.getFieldName())
|
||||
)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
final VirtualColumn virtualInput =
|
||||
virtualColumnRegistry.findVirtualColumns(factory.requiredFields())
|
||||
.stream()
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
if (virtualInput == null) {
|
||||
inputMatches = input.isDirectColumnAccess()
|
||||
&& input.getDirectColumn().equals(factory.getFieldName());
|
||||
inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(factory.getFieldName());
|
||||
} else {
|
||||
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
|
||||
.equals(input.getExpression());
|
||||
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
|
||||
}
|
||||
return inputMatches && compression == factory.getCompression();
|
||||
}
|
||||
|
|
|
@ -47,7 +47,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
||||
|
@ -112,9 +111,9 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
|||
if (factory instanceof TDigestSketchAggregatorFactory) {
|
||||
final TDigestSketchAggregatorFactory theFactory = (TDigestSketchAggregatorFactory) factory;
|
||||
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
|
||||
virtualColumnRegistry,
|
||||
input,
|
||||
compression,
|
||||
existing,
|
||||
(TDigestSketchAggregatorFactory) factory
|
||||
);
|
||||
|
||||
|
@ -129,8 +128,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (input.isDirectColumnAccess()) {
|
||||
aggregatorFactory = new TDigestSketchAggregatorFactory(
|
||||
aggName,
|
||||
|
@ -143,7 +140,6 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
|||
input,
|
||||
ValueType.FLOAT
|
||||
);
|
||||
virtualColumns.add(virtualColumn);
|
||||
aggregatorFactory = new TDigestSketchAggregatorFactory(
|
||||
aggName,
|
||||
virtualColumn.getOutputName(),
|
||||
|
@ -151,10 +147,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
|
|||
);
|
||||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
aggregatorFactory
|
||||
);
|
||||
return Aggregation.create(aggregatorFactory);
|
||||
}
|
||||
|
||||
private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
|
||||
|
|
|
@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
||||
|
@ -123,9 +122,9 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
|||
for (AggregatorFactory factory : existing.getAggregatorFactories()) {
|
||||
if (factory instanceof TDigestSketchAggregatorFactory) {
|
||||
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
|
||||
virtualColumnRegistry,
|
||||
input,
|
||||
compression,
|
||||
existing,
|
||||
(TDigestSketchAggregatorFactory) factory
|
||||
);
|
||||
|
||||
|
@ -148,8 +147,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (input.isDirectColumnAccess()) {
|
||||
aggregatorFactory = new TDigestSketchAggregatorFactory(
|
||||
sketchName,
|
||||
|
@ -162,7 +159,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
|||
input,
|
||||
ValueType.FLOAT
|
||||
);
|
||||
virtualColumns.add(virtualColumn);
|
||||
aggregatorFactory = new TDigestSketchAggregatorFactory(
|
||||
sketchName,
|
||||
virtualColumn.getOutputName(),
|
||||
|
@ -171,7 +167,6 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
new TDigestSketchToQuantilePostAggregator(
|
||||
name,
|
||||
|
|
|
@ -29,12 +29,10 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
|
|||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
|
||||
{
|
||||
|
@ -51,12 +49,10 @@ public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlA
|
|||
protected Aggregation toAggregation(
|
||||
String name,
|
||||
boolean finalizeAggregations,
|
||||
List<VirtualColumn> virtualColumns,
|
||||
AggregatorFactory aggregatorFactory
|
||||
)
|
||||
{
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
Collections.singletonList(aggregatorFactory),
|
||||
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
|
||||
name,
|
||||
|
|
|
@ -45,7 +45,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
|
||||
|
@ -115,7 +114,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
|
|||
tgtHllType = HllSketchAggregatorFactory.DEFAULT_TGT_HLL_TYPE.name();
|
||||
}
|
||||
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
final AggregatorFactory aggregatorFactory;
|
||||
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
|
||||
|
||||
|
@ -150,7 +148,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
|
|||
dataType
|
||||
);
|
||||
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
|
||||
virtualColumns.add(virtualColumn);
|
||||
}
|
||||
|
||||
aggregatorFactory = new HllSketchBuildAggregatorFactory(
|
||||
|
@ -165,7 +162,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
|
|||
return toAggregation(
|
||||
name,
|
||||
finalizeAggregations,
|
||||
virtualColumns,
|
||||
aggregatorFactory
|
||||
);
|
||||
}
|
||||
|
@ -173,7 +169,6 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
|
|||
protected abstract Aggregation toAggregation(
|
||||
String name,
|
||||
boolean finalizeAggregations,
|
||||
List<VirtualColumn> virtualColumns,
|
||||
AggregatorFactory aggregatorFactory
|
||||
);
|
||||
}
|
||||
|
|
|
@ -28,12 +28,10 @@ import org.apache.calcite.sql.type.ReturnTypes;
|
|||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
|
||||
{
|
||||
|
@ -50,12 +48,10 @@ public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator imp
|
|||
protected Aggregation toAggregation(
|
||||
String name,
|
||||
boolean finalizeAggregations,
|
||||
List<VirtualColumn> virtualColumns,
|
||||
AggregatorFactory aggregatorFactory
|
||||
)
|
||||
{
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
Collections.singletonList(aggregatorFactory),
|
||||
null
|
||||
);
|
||||
|
|
|
@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
|
||||
|
@ -132,22 +131,16 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
|
|||
|
||||
// Check input for equivalence.
|
||||
final boolean inputMatches;
|
||||
final VirtualColumn virtualInput = existing.getVirtualColumns()
|
||||
.stream()
|
||||
.filter(
|
||||
virtualColumn ->
|
||||
virtualColumn.getOutputName()
|
||||
.equals(theFactory.getFieldName())
|
||||
)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
final VirtualColumn virtualInput =
|
||||
virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
|
||||
.stream()
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
if (virtualInput == null) {
|
||||
inputMatches = input.isDirectColumnAccess()
|
||||
&& input.getDirectColumn().equals(theFactory.getFieldName());
|
||||
inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(theFactory.getFieldName());
|
||||
} else {
|
||||
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
|
||||
.equals(input.getExpression());
|
||||
inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
|
||||
}
|
||||
|
||||
final boolean matches = inputMatches
|
||||
|
@ -172,8 +165,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (input.isDirectColumnAccess()) {
|
||||
aggregatorFactory = new DoublesSketchAggregatorFactory(
|
||||
histogramName,
|
||||
|
@ -186,7 +177,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
|
|||
input,
|
||||
ValueType.FLOAT
|
||||
);
|
||||
virtualColumns.add(virtualColumn);
|
||||
aggregatorFactory = new DoublesSketchAggregatorFactory(
|
||||
histogramName,
|
||||
virtualColumn.getOutputName(),
|
||||
|
@ -195,7 +185,6 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
new DoublesSketchToQuantilePostAggregator(
|
||||
name,
|
||||
|
|
|
@ -47,7 +47,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class DoublesSketchObjectSqlAggregator implements SqlAggregator
|
||||
|
@ -110,8 +109,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (input.isDirectColumnAccess()) {
|
||||
aggregatorFactory = new DoublesSketchAggregatorFactory(
|
||||
histogramName,
|
||||
|
@ -124,7 +121,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
|
|||
input,
|
||||
ValueType.FLOAT
|
||||
);
|
||||
virtualColumns.add(virtualColumn);
|
||||
aggregatorFactory = new DoublesSketchAggregatorFactory(
|
||||
histogramName,
|
||||
virtualColumn.getOutputName(),
|
||||
|
@ -133,7 +129,6 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
null
|
||||
);
|
||||
|
|
|
@ -29,12 +29,10 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
|
|||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
|
||||
{
|
||||
|
@ -51,12 +49,10 @@ public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBase
|
|||
protected Aggregation toAggregation(
|
||||
String name,
|
||||
boolean finalizeAggregations,
|
||||
List<VirtualColumn> virtualColumns,
|
||||
AggregatorFactory aggregatorFactory
|
||||
)
|
||||
{
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
Collections.singletonList(aggregatorFactory),
|
||||
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
|
||||
name,
|
||||
|
|
|
@ -44,7 +44,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
|
||||
|
@ -94,7 +93,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
|
|||
sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE;
|
||||
}
|
||||
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
final AggregatorFactory aggregatorFactory;
|
||||
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
|
||||
|
||||
|
@ -130,7 +128,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
|
|||
dataType
|
||||
);
|
||||
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
|
||||
virtualColumns.add(virtualColumn);
|
||||
}
|
||||
|
||||
aggregatorFactory = new SketchMergeAggregatorFactory(
|
||||
|
@ -146,7 +143,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
|
|||
return toAggregation(
|
||||
name,
|
||||
finalizeAggregations,
|
||||
virtualColumns,
|
||||
aggregatorFactory
|
||||
);
|
||||
}
|
||||
|
@ -154,7 +150,6 @@ public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
|
|||
protected abstract Aggregation toAggregation(
|
||||
String name,
|
||||
boolean finalizeAggregations,
|
||||
List<VirtualColumn> virtualColumns,
|
||||
AggregatorFactory aggregatorFactory
|
||||
);
|
||||
}
|
||||
|
|
|
@ -28,12 +28,10 @@ import org.apache.calcite.sql.type.ReturnTypes;
|
|||
import org.apache.calcite.sql.type.SqlTypeFamily;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
|
||||
{
|
||||
|
@ -50,12 +48,10 @@ public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator
|
|||
protected Aggregation toAggregation(
|
||||
String name,
|
||||
boolean finalizeAggregations,
|
||||
List<VirtualColumn> virtualColumns,
|
||||
AggregatorFactory aggregatorFactory
|
||||
)
|
||||
{
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
Collections.singletonList(aggregatorFactory),
|
||||
null
|
||||
);
|
||||
|
|
|
@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class BloomFilterSqlAggregator implements SqlAggregator
|
||||
|
@ -115,15 +114,10 @@ public class BloomFilterSqlAggregator implements SqlAggregator
|
|||
|
||||
// Check input for equivalence.
|
||||
final boolean inputMatches;
|
||||
final VirtualColumn virtualInput =
|
||||
existing.getVirtualColumns()
|
||||
.stream()
|
||||
.filter(virtualColumn ->
|
||||
virtualColumn.getOutputName().equals(theFactory.getField().getOutputName())
|
||||
)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
final VirtualColumn virtualInput = virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
|
||||
.stream()
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
if (virtualInput == null) {
|
||||
if (input.isDirectColumnAccess()) {
|
||||
inputMatches =
|
||||
|
@ -150,7 +144,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
ValueType valueType = Calcites.getValueTypeForRelDataType(inputOperand.getType());
|
||||
final DimensionSpec spec;
|
||||
|
@ -173,7 +166,6 @@ public class BloomFilterSqlAggregator implements SqlAggregator
|
|||
input,
|
||||
inputOperand.getType()
|
||||
);
|
||||
virtualColumns.add(virtualColumn);
|
||||
spec = new DefaultDimensionSpec(
|
||||
virtualColumn.getOutputName(),
|
||||
StringUtils.format("%s:%s", name, virtualColumn.getOutputName())
|
||||
|
@ -186,10 +178,7 @@ public class BloomFilterSqlAggregator implements SqlAggregator
|
|||
maxNumEntries
|
||||
);
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
aggregatorFactory
|
||||
);
|
||||
return Aggregation.create(aggregatorFactory);
|
||||
}
|
||||
|
||||
private static class BloomFilterSqlAggFunction extends SqlAggFunction
|
||||
|
|
|
@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
||||
|
@ -188,15 +187,11 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
|||
|
||||
// Check input for equivalence.
|
||||
final boolean inputMatches;
|
||||
final VirtualColumn virtualInput = existing.getVirtualColumns()
|
||||
.stream()
|
||||
.filter(
|
||||
virtualColumn ->
|
||||
virtualColumn.getOutputName()
|
||||
.equals(theFactory.getFieldName())
|
||||
)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
final VirtualColumn virtualInput =
|
||||
virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
|
||||
.stream()
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
if (virtualInput == null) {
|
||||
inputMatches = input.isDirectColumnAccess()
|
||||
|
@ -224,8 +219,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (input.isDirectColumnAccess()) {
|
||||
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
|
||||
histogramName,
|
||||
|
@ -242,7 +235,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
|||
input,
|
||||
ValueType.FLOAT
|
||||
);
|
||||
virtualColumns.add(virtualColumn);
|
||||
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
|
||||
histogramName,
|
||||
virtualColumn.getOutputName(),
|
||||
|
@ -255,7 +247,6 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
new QuantilePostAggregator(name, histogramName, probability)
|
||||
);
|
||||
|
|
|
@ -51,7 +51,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class QuantileSqlAggregator implements SqlAggregator
|
||||
|
@ -137,15 +136,11 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
|
||||
// Check input for equivalence.
|
||||
final boolean inputMatches;
|
||||
final VirtualColumn virtualInput = existing.getVirtualColumns()
|
||||
.stream()
|
||||
.filter(
|
||||
virtualColumn ->
|
||||
virtualColumn.getOutputName()
|
||||
.equals(theFactory.getFieldName())
|
||||
)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
final VirtualColumn virtualInput =
|
||||
virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
|
||||
.stream()
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
if (virtualInput == null) {
|
||||
inputMatches = input.isDirectColumnAccess()
|
||||
|
@ -173,8 +168,6 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
// No existing match found. Create a new one.
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (input.isDirectColumnAccess()) {
|
||||
if (rowSignature.getColumnType(input.getDirectColumn()).orElse(null) == ValueType.COMPLEX) {
|
||||
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
|
||||
|
@ -200,7 +193,6 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
} else {
|
||||
final VirtualColumn virtualColumn =
|
||||
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, ValueType.FLOAT);
|
||||
virtualColumns.add(virtualColumn);
|
||||
aggregatorFactory = new ApproximateHistogramAggregatorFactory(
|
||||
histogramName,
|
||||
virtualColumn.getOutputName(),
|
||||
|
@ -213,7 +205,6 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
new QuantilePostAggregator(name, histogramName, probability)
|
||||
);
|
||||
|
|
|
@ -48,7 +48,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
||||
|
@ -84,7 +83,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
final AggregatorFactory aggregatorFactory;
|
||||
final RelDataType dataType = inputOperand.getType();
|
||||
final ValueType inputType = Calcites.getValueTypeForRelDataType(dataType);
|
||||
final List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
final DimensionSpec dimensionSpec;
|
||||
final String aggName = StringUtils.format("%s:agg", name);
|
||||
final SqlAggFunction func = calciteFunction();
|
||||
|
@ -98,7 +96,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
VirtualColumn virtualColumn =
|
||||
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, dataType);
|
||||
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
|
||||
virtualColumns.add(virtualColumn);
|
||||
}
|
||||
|
||||
switch (inputType) {
|
||||
|
@ -135,7 +132,6 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
postAggregator
|
||||
);
|
||||
|
|
|
@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonCreator;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import org.apache.druid.query.PerSegmentQueryOptimizationContext;
|
||||
import org.apache.druid.query.filter.DimFilter;
|
||||
import org.apache.druid.query.filter.Filter;
|
||||
|
@ -166,7 +168,10 @@ public class FilteredAggregatorFactory extends AggregatorFactory
|
|||
@Override
|
||||
public List<String> requiredFields()
|
||||
{
|
||||
return delegate.requiredFields();
|
||||
return ImmutableList.copyOf(
|
||||
// use a set to get rid of dupes
|
||||
ImmutableSet.<String>builder().addAll(delegate.requiredFields()).addAll(filter.getRequiredColumns()).build()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -19,11 +19,14 @@
|
|||
|
||||
package org.apache.druid.query.aggregation;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.apache.druid.query.filter.SelectorDimFilter;
|
||||
import org.apache.druid.query.filter.TrueDimFilter;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class FilteredAggregatorFactoryTest
|
||||
public class FilteredAggregatorFactoryTest extends InitializedNullHandlingTest
|
||||
{
|
||||
@Test
|
||||
public void testSimpleNaming()
|
||||
|
@ -44,4 +47,16 @@ public class FilteredAggregatorFactoryTest
|
|||
null
|
||||
).getName());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRequiredFields()
|
||||
{
|
||||
Assert.assertEquals(
|
||||
ImmutableList.of("x", "y"),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("x", "x"),
|
||||
new SelectorDimFilter("y", "wat", null)
|
||||
).requiredFields()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
|||
import org.apache.druid.query.aggregation.PostAggregator;
|
||||
import org.apache.druid.query.filter.AndDimFilter;
|
||||
import org.apache.druid.query.filter.DimFilter;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.sql.calcite.filtration.Filtration;
|
||||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
@ -44,17 +43,14 @@ import java.util.Set;
|
|||
|
||||
public class Aggregation
|
||||
{
|
||||
private final List<VirtualColumn> virtualColumns;
|
||||
private final List<AggregatorFactory> aggregatorFactories;
|
||||
private final PostAggregator postAggregator;
|
||||
|
||||
private Aggregation(
|
||||
final List<VirtualColumn> virtualColumns,
|
||||
final List<AggregatorFactory> aggregatorFactories,
|
||||
final PostAggregator postAggregator
|
||||
)
|
||||
{
|
||||
this.virtualColumns = Preconditions.checkNotNull(virtualColumns, "virtualColumns");
|
||||
this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories");
|
||||
this.postAggregator = postAggregator;
|
||||
|
||||
|
@ -88,19 +84,10 @@ public class Aggregation
|
|||
}
|
||||
}
|
||||
|
||||
public static Aggregation create(final List<VirtualColumn> virtualColumns, final AggregatorFactory aggregatorFactory)
|
||||
{
|
||||
return new Aggregation(
|
||||
virtualColumns,
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
null
|
||||
);
|
||||
}
|
||||
|
||||
public static Aggregation create(final AggregatorFactory aggregatorFactory)
|
||||
{
|
||||
return new Aggregation(
|
||||
ImmutableList.of(),
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
null
|
||||
);
|
||||
|
@ -108,7 +95,7 @@ public class Aggregation
|
|||
|
||||
public static Aggregation create(final PostAggregator postAggregator)
|
||||
{
|
||||
return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
|
||||
return new Aggregation(Collections.emptyList(), postAggregator);
|
||||
}
|
||||
|
||||
public static Aggregation create(
|
||||
|
@ -116,21 +103,19 @@ public class Aggregation
|
|||
final PostAggregator postAggregator
|
||||
)
|
||||
{
|
||||
return new Aggregation(ImmutableList.of(), aggregatorFactories, postAggregator);
|
||||
return new Aggregation(aggregatorFactories, postAggregator);
|
||||
}
|
||||
|
||||
public static Aggregation create(
|
||||
final List<VirtualColumn> virtualColumns,
|
||||
final List<AggregatorFactory> aggregatorFactories,
|
||||
final PostAggregator postAggregator
|
||||
)
|
||||
public List<String> getRequiredColumns()
|
||||
{
|
||||
return new Aggregation(virtualColumns, aggregatorFactories, postAggregator);
|
||||
}
|
||||
|
||||
public List<VirtualColumn> getVirtualColumns()
|
||||
{
|
||||
return virtualColumns;
|
||||
Set<String> columns = new HashSet<>();
|
||||
for (AggregatorFactory agg : aggregatorFactories) {
|
||||
columns.addAll(agg.requiredFields());
|
||||
}
|
||||
if (postAggregator != null) {
|
||||
columns.addAll(postAggregator.getDependentFields());
|
||||
}
|
||||
return ImmutableList.copyOf(columns);
|
||||
}
|
||||
|
||||
public List<AggregatorFactory> getAggregatorFactories()
|
||||
|
@ -181,21 +166,10 @@ public class Aggregation
|
|||
.optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature())
|
||||
.getDimFilter();
|
||||
|
||||
Set<VirtualColumn> aggVirtualColumnsPlusFilterColumns = new HashSet<>(virtualColumns);
|
||||
for (String column : baseOptimizedFilter.getRequiredColumns()) {
|
||||
if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
|
||||
aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
|
||||
}
|
||||
}
|
||||
final List<AggregatorFactory> newAggregators = new ArrayList<>();
|
||||
for (AggregatorFactory agg : aggregatorFactories) {
|
||||
if (agg instanceof FilteredAggregatorFactory) {
|
||||
final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg;
|
||||
for (String column : filteredAgg.getFilter().getRequiredColumns()) {
|
||||
if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
|
||||
aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
|
||||
}
|
||||
}
|
||||
newAggregators.add(
|
||||
new FilteredAggregatorFactory(
|
||||
filteredAgg.getAggregator(),
|
||||
|
@ -209,7 +183,7 @@ public class Aggregation
|
|||
}
|
||||
}
|
||||
|
||||
return new Aggregation(new ArrayList<>(aggVirtualColumnsPlusFilterColumns), newAggregators, postAggregator);
|
||||
return new Aggregation(newAggregators, postAggregator);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -222,23 +196,21 @@ public class Aggregation
|
|||
return false;
|
||||
}
|
||||
final Aggregation that = (Aggregation) o;
|
||||
return Objects.equals(virtualColumns, that.virtualColumns) &&
|
||||
Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
|
||||
return Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
|
||||
Objects.equals(postAggregator, that.postAggregator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode()
|
||||
{
|
||||
return Objects.hash(virtualColumns, aggregatorFactories, postAggregator);
|
||||
return Objects.hash(aggregatorFactories, postAggregator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString()
|
||||
{
|
||||
return "Aggregation{" +
|
||||
"virtualColumns=" + virtualColumns +
|
||||
", aggregatorFactories=" + aggregatorFactories +
|
||||
"aggregatorFactories=" + aggregatorFactories +
|
||||
", postAggregator=" + postAggregator +
|
||||
'}';
|
||||
}
|
||||
|
|
|
@ -52,7 +52,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -94,7 +93,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
|
|||
return null;
|
||||
}
|
||||
|
||||
final List<VirtualColumn> myvirtualColumns = new ArrayList<>();
|
||||
final AggregatorFactory aggregatorFactory;
|
||||
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
|
||||
|
||||
|
@ -120,7 +118,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
|
|||
VirtualColumn virtualColumn =
|
||||
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, dataType);
|
||||
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
|
||||
myvirtualColumns.add(virtualColumn);
|
||||
}
|
||||
|
||||
aggregatorFactory = new CardinalityAggregatorFactory(
|
||||
|
@ -133,7 +130,6 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
myvirtualColumns,
|
||||
Collections.singletonList(aggregatorFactory),
|
||||
finalizeAggregations ? new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName()) : null
|
||||
);
|
||||
|
|
|
@ -53,7 +53,6 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -134,19 +133,15 @@ public class ArraySqlAggregator implements SqlAggregator
|
|||
break;
|
||||
}
|
||||
}
|
||||
List<VirtualColumn> virtualColumns = new ArrayList<>();
|
||||
|
||||
if (arg.isDirectColumnAccess()) {
|
||||
fieldName = arg.getDirectColumn();
|
||||
} else {
|
||||
VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType);
|
||||
virtualColumns.add(vc);
|
||||
fieldName = vc.getOutputName();
|
||||
}
|
||||
|
||||
if (aggregateCall.isDistinct()) {
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
name,
|
||||
ImmutableSet.of(fieldName),
|
||||
|
@ -163,7 +158,6 @@ public class ArraySqlAggregator implements SqlAggregator
|
|||
);
|
||||
} else {
|
||||
return Aggregation.create(
|
||||
virtualColumns,
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
name,
|
||||
ImmutableSet.of(fieldName),
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
|
|||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
|
||||
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
|
@ -78,37 +79,7 @@ public class AvgSqlAggregator implements SqlAggregator
|
|||
return null;
|
||||
}
|
||||
|
||||
final String fieldName;
|
||||
final String expression;
|
||||
final DruidExpression arg = Iterables.getOnlyElement(arguments);
|
||||
|
||||
if (arg.isDirectColumnAccess()) {
|
||||
fieldName = arg.getDirectColumn();
|
||||
expression = null;
|
||||
} else {
|
||||
fieldName = null;
|
||||
expression = arg.getExpression();
|
||||
}
|
||||
|
||||
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
|
||||
|
||||
final ValueType sumType;
|
||||
// Use 64-bit sum regardless of the type of the AVG aggregator.
|
||||
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
|
||||
sumType = ValueType.LONG;
|
||||
} else {
|
||||
sumType = ValueType.DOUBLE;
|
||||
}
|
||||
|
||||
final String sumName = Calcites.makePrefixedName(name, "sum");
|
||||
final String countName = Calcites.makePrefixedName(name, "count");
|
||||
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
|
||||
sumType,
|
||||
sumName,
|
||||
fieldName,
|
||||
expression,
|
||||
macroTable
|
||||
);
|
||||
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
|
||||
countName,
|
||||
plannerContext,
|
||||
|
@ -119,6 +90,38 @@ public class AvgSqlAggregator implements SqlAggregator
|
|||
project
|
||||
);
|
||||
|
||||
final String fieldName;
|
||||
final String expression;
|
||||
final DruidExpression arg = Iterables.getOnlyElement(arguments);
|
||||
|
||||
|
||||
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
|
||||
final ValueType sumType;
|
||||
// Use 64-bit sum regardless of the type of the AVG aggregator.
|
||||
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
|
||||
sumType = ValueType.LONG;
|
||||
} else {
|
||||
sumType = ValueType.DOUBLE;
|
||||
}
|
||||
|
||||
if (arg.isDirectColumnAccess()) {
|
||||
fieldName = arg.getDirectColumn();
|
||||
expression = null;
|
||||
} else {
|
||||
// if the filter or anywhere else defined a virtual column for us, re-use it
|
||||
VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
|
||||
fieldName = vc != null ? vc.getOutputName() : null;
|
||||
expression = vc != null ? null : arg.getExpression();
|
||||
}
|
||||
final String sumName = Calcites.makePrefixedName(name, "sum");
|
||||
final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
|
||||
sumType,
|
||||
sumName,
|
||||
fieldName,
|
||||
expression,
|
||||
macroTable
|
||||
);
|
||||
|
||||
return Aggregation.create(
|
||||
ImmutableList.of(sum, count),
|
||||
new ArithmeticPostAggregator(
|
||||
|
|
|
@ -134,7 +134,7 @@ public class CountSqlAggregator implements SqlAggregator
|
|||
} else {
|
||||
// Not COUNT(*), not distinct
|
||||
// COUNT(x) should count all non-null values of x.
|
||||
return Aggregation.create(createCountAggregatorFactory(
|
||||
AggregatorFactory theCount = createCountAggregatorFactory(
|
||||
name,
|
||||
plannerContext,
|
||||
rowSignature,
|
||||
|
@ -142,7 +142,9 @@ public class CountSqlAggregator implements SqlAggregator
|
|||
rexBuilder,
|
||||
aggregateCall,
|
||||
project
|
||||
));
|
||||
);
|
||||
|
||||
return Aggregation.create(theCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,9 +64,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
|||
import javax.annotation.Nullable;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
||||
{
|
||||
|
@ -209,9 +207,6 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
return Aggregation.create(
|
||||
Stream.of(virtualColumnRegistry.getVirtualColumn(fieldName))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList()),
|
||||
Collections.singletonList(
|
||||
aggregatorType.createAggregatorFactory(
|
||||
aggregatorName,
|
||||
|
|
|
@ -636,7 +636,7 @@ public class DruidQuery
|
|||
}
|
||||
|
||||
for (Aggregation aggregation : grouping.getAggregations()) {
|
||||
virtualColumns.addAll(aggregation.getVirtualColumns());
|
||||
virtualColumns.addAll(virtualColumnRegistry.findVirtualColumns(aggregation.getRequiredColumns()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,9 @@ import org.apache.druid.sql.calcite.planner.PlannerContext;
|
|||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Provides facilities to create and re-use {@link VirtualColumn} definitions for dimensions, filters, and filtered
|
||||
|
@ -128,6 +130,12 @@ public class VirtualColumnRegistry
|
|||
return virtualColumnsByName.get(virtualColumnName);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public VirtualColumn getVirtualColumnByExpression(String expression)
|
||||
{
|
||||
return virtualColumnsByExpression.get(expression);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a signature representing the base signature plus all registered virtual columns.
|
||||
*/
|
||||
|
@ -145,4 +153,15 @@ public class VirtualColumnRegistry
|
|||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a list of column names, find any corresponding {@link VirtualColumn} with the same name
|
||||
*/
|
||||
public List<VirtualColumn> findVirtualColumns(List<String> allColumns)
|
||||
{
|
||||
return allColumns.stream()
|
||||
.filter(this::isVirtualColumnDefined)
|
||||
.map(this::getVirtualColumn)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -117,7 +117,6 @@ public class GroupByRules
|
|||
if (doesMatch) {
|
||||
existingAggregationsWithSameFilter.add(
|
||||
Aggregation.create(
|
||||
existingAggregation.getVirtualColumns(),
|
||||
existingAggregation.getAggregatorFactories().stream()
|
||||
.map(factory -> ((FilteredAggregatorFactory) factory).getAggregator())
|
||||
.collect(Collectors.toList()),
|
||||
|
|
|
@ -50,6 +50,7 @@ import org.apache.druid.query.QueryException;
|
|||
import org.apache.druid.query.ResourceLimitExceededException;
|
||||
import org.apache.druid.query.TableDataSource;
|
||||
import org.apache.druid.query.UnionDataSource;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
|
||||
|
@ -112,6 +113,8 @@ import org.apache.druid.query.topn.DimensionTopNMetricSpec;
|
|||
import org.apache.druid.query.topn.InvertedTopNMetricSpec;
|
||||
import org.apache.druid.query.topn.NumericTopNMetricSpec;
|
||||
import org.apache.druid.query.topn.TopNQueryBuilder;
|
||||
import org.apache.druid.segment.VirtualColumn;
|
||||
import org.apache.druid.segment.VirtualColumns;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.segment.join.JoinType;
|
||||
|
@ -18926,4 +18929,76 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
expectedResults
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountAndAverageByConstantVirtualColumn() throws Exception
|
||||
{
|
||||
List<VirtualColumn> virtualColumns;
|
||||
List<AggregatorFactory> aggs;
|
||||
if (useDefault) {
|
||||
aggs = ImmutableList.of(
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0"),
|
||||
not(selector("v0", null, null))
|
||||
),
|
||||
new LongSumAggregatorFactory("a1:sum", null, "325323", TestExprMacroTable.INSTANCE),
|
||||
new CountAggregatorFactory("a1:count")
|
||||
);
|
||||
virtualColumns = ImmutableList.of(
|
||||
expressionVirtualColumn("v0", "'10.1'", ValueType.STRING)
|
||||
);
|
||||
} else {
|
||||
aggs = ImmutableList.of(
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0"),
|
||||
not(selector("v0", null, null))
|
||||
),
|
||||
new LongSumAggregatorFactory("a1:sum", "v1"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a1:count"),
|
||||
not(selector("v1", null, null))
|
||||
)
|
||||
);
|
||||
virtualColumns = ImmutableList.of(
|
||||
expressionVirtualColumn("v0", "'10.1'", ValueType.STRING),
|
||||
expressionVirtualColumn("v1", "325323", ValueType.LONG)
|
||||
);
|
||||
|
||||
}
|
||||
testQuery(
|
||||
"SELECT dim5, COUNT(dim1), AVG(l1) FROM druid.numfoo WHERE dim1 = '10.1' AND l1 = 325323 GROUP BY dim5",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE3)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimFilter(
|
||||
and(
|
||||
selector("dim1", "10.1", null),
|
||||
selector("l1", "325323", null)
|
||||
)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setVirtualColumns(VirtualColumns.create(virtualColumns))
|
||||
.setDimensions(new DefaultDimensionSpec("dim5", "_d0", ValueType.STRING))
|
||||
.setAggregatorSpecs(aggs)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
new ArithmeticPostAggregator(
|
||||
"a1",
|
||||
"quotient",
|
||||
ImmutableList.of(
|
||||
new FieldAccessPostAggregator(null, "a1:sum"),
|
||||
new FieldAccessPostAggregator(null, "a1:count")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{"ab", 1L, 325323L}
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue