Fix avg sql aggregator (#10135)

* new average aggregator

* method to create count aggregator factory

* test everything

* update other usages

* fix style

* fix more tests

* fix datasketches tests
This commit is contained in:
Franklyn Dsouza 2020-07-08 11:38:56 -04:00 committed by GitHub
parent c776e412e0
commit 1b9aacb1cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 247 additions and 72 deletions

View File

@ -388,10 +388,20 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase
)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
))
.setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(

View File

@ -385,10 +385,20 @@ public class ThetaSketchSqlAggregatorTest extends CalciteTestBase
)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
))
.setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(

View File

@ -20,20 +20,31 @@
package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
public class AvgSqlAggregator extends SimpleSqlAggregator
import javax.annotation.Nullable;
import java.util.List;
public class AvgSqlAggregator implements SqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
@ -41,15 +52,46 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
return SqlStdOperatorTable.AVG;
}
@Nullable
@Override
Aggregation getAggregation(
public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name,
final AggregateCall aggregateCall,
final ExprMacroTable macroTable,
final String fieldName,
final String expression
final Project project,
final List<Aggregation> existingAggregations,
final boolean finalizeAggregations
)
{
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
plannerContext,
rowSignature,
aggregateCall,
project
);
if (arguments == null) {
return null;
}
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
fieldName = null;
expression = arg.getExpression();
}
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
@ -67,8 +109,15 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
expression,
macroTable
);
final AggregatorFactory count = new CountAggregatorFactory(countName);
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
countName,
plannerContext,
rowSignature,
virtualColumnRegistry,
rexBuilder,
aggregateCall,
project
);
return Aggregation.create(
ImmutableList.of(sum, count),

View File

@ -28,7 +28,9 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
@ -52,6 +54,41 @@ public class CountSqlAggregator implements SqlAggregator
return SqlStdOperatorTable.COUNT;
}
/**
 * Builds the aggregator factory used to count the rows of the single argument of {@code aggregateCall}.
 *
 * The argument is resolved through {@link Expressions#fromFieldAccess}. When its type is not nullable, a plain
 * {@link CountAggregatorFactory} is returned. When its type is nullable, the count is wrapped in a
 * {@link FilteredAggregatorFactory} whose filter is the translation of {@code <argument> IS NOT NULL}, so that
 * only non-null values are counted.
 *
 * @param countName             output name given to the count aggregator
 * @param plannerContext        planner context passed to the filter translation
 * @param rowSignature          signature used to resolve the argument and to translate the filter
 * @param virtualColumnRegistry registry passed to the filter translation
 * @param rexBuilder            used to build the IS NOT NULL call on the argument
 * @param aggregateCall         the aggregate call; its argument list is read with
 *                              {@code Iterables.getOnlyElement}, so exactly one argument is expected
 * @param project               projection used when resolving the argument
 *
 * @return a count aggregator factory, filtered on non-null values when the argument type is nullable
 *
 * @throws ISE if the not-null filter cannot be created for a nullable argument
 */
static AggregatorFactory createCountAggregatorFactory(
    final String countName,
    final PlannerContext plannerContext,
    final RowSignature rowSignature,
    final VirtualColumnRegistry virtualColumnRegistry,
    final RexBuilder rexBuilder,
    final AggregateCall aggregateCall,
    final Project project
)
{
  final RexNode argument = Expressions.fromFieldAccess(
      rowSignature,
      project,
      Iterables.getOnlyElement(aggregateCall.getArgList())
  );

  // A non-nullable argument needs no filtering: every row is counted.
  if (!argument.getType().isNullable()) {
    return new CountAggregatorFactory(countName);
  }

  final RexNode isNotNullCall = rexBuilder.makeCall(
      SqlStdOperatorTable.IS_NOT_NULL,
      ImmutableList.of(argument)
  );
  final DimFilter notNullFilter = Expressions.toFilter(
      plannerContext,
      rowSignature,
      virtualColumnRegistry,
      isNotNullCall
  );

  if (notNullFilter == null) {
    // Don't expect this to happen.
    throw new ISE("Could not create not-null filter for rexNode[%s]", argument);
  }

  return new FilteredAggregatorFactory(new CountAggregatorFactory(countName), notNullFilter);
}
@Nullable
@Override
public Aggregation toDruidAggregation(
@ -96,32 +133,16 @@ public class CountSqlAggregator implements SqlAggregator
}
} else {
// Not COUNT(*), not distinct
// COUNT(x) should count all non-null values of x.
final RexNode rexNode = Expressions.fromFieldAccess(
rowSignature,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
if (rexNode.getType().isNullable()) {
final DimFilter nonNullFilter = Expressions.toFilter(
return Aggregation.create(createCountAggregatorFactory(
name,
plannerContext,
rowSignature,
virtualColumnRegistry,
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode))
);
if (nonNullFilter == null) {
// Don't expect this to happen.
throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode);
}
return Aggregation.create(new CountAggregatorFactory(name))
.filter(rowSignature, virtualColumnRegistry, nonNullFilter);
} else {
return Aggregation.create(new CountAggregatorFactory(name));
}
rexBuilder,
aggregateCall,
project
));
}
}
}

View File

@ -244,10 +244,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count")
)
.setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count")
)
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
@ -313,10 +322,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count")
)
.setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count")
)
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
@ -390,10 +408,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count")
)
.setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count")
)
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
@ -4730,11 +4757,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
cannotVectorize();
testQuery(
"SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2) FROM druid.foo",
"SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2), COUNT(d1), AVG(d1) FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(
@ -4753,7 +4780,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FilteredAggregatorFactory(
new CountAggregatorFactory("a6"),
not(selector("dim2", null, null))
)
),
new DoubleSumAggregatorFactory("a7:sum", "d1"),
new CountAggregatorFactory("a7:count")
)
: aggregators(
new CountAggregatorFactory("a0"),
@ -4766,13 +4795,25 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
not(selector("dim1", null, null))
),
new LongSumAggregatorFactory("a3:sum", "cnt"),
new CountAggregatorFactory("a3:count"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a3:count"),
not(selector("cnt", null, null))
),
new LongSumAggregatorFactory("a4", "cnt"),
new LongMinAggregatorFactory("a5", "cnt"),
new LongMaxAggregatorFactory("a6", "cnt"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a7"),
not(selector("dim2", null, null))
),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a8"),
not(selector("d1", null, null))
),
new DoubleSumAggregatorFactory("a9:sum", "d1"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a9:count"),
not(selector("d1", null, null))
)
)
)
@ -4785,6 +4826,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FieldAccessPostAggregator(null, useDefault ? "a2:count" : "a3:count")
)
),
new ArithmeticPostAggregator(
useDefault ? "a7" : "a9",
"quotient",
ImmutableList.of(
new FieldAccessPostAggregator(null, useDefault ? "a7:sum" : "a9:sum"),
new FieldAccessPostAggregator(null, useDefault ? "a7:count" : "a9:count")
)
),
expressionPostAgg(
"p0",
useDefault ? "((\"a3\" + \"a4\") + \"a5\")" : "((\"a4\" + \"a5\") + \"a6\")"
@ -4795,10 +4844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
),
NullHandling.replaceWithDefault() ?
ImmutableList.of(
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L}
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
) :
ImmutableList.of(
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L}
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
)
);
}
@ -6801,14 +6850,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new CountAggregatorFactory("_a2:count"),
new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4")
))
.setAggregatorSpecs(
useDefault
? aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new CountAggregatorFactory("_a2:count"),
new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4")
)
: aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a2:count"),
not(selector("a0", null, null))
),
new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4")
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
@ -6872,10 +6935,20 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
))
.setAggregatorSpecs(
useDefault
? aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
not(selector("a0", null, null))
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
@ -12935,10 +13008,22 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.dimension(new DefaultDimensionSpec("m1", "d0", ValueType.FLOAT))
.filters("dim2", "a")
.aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count"),
new DoubleSumAggregatorFactory("a1", "m1"),
new DoubleSumAggregatorFactory("a2", "m2")
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count"),
new DoubleSumAggregatorFactory("a1", "m1"),
new DoubleSumAggregatorFactory("a2", "m2")
)
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
),
new DoubleSumAggregatorFactory("a1", "m1"),
new DoubleSumAggregatorFactory("a2", "m2")
)
)
.postAggregators(
new ArithmeticPostAggregator(