From 1b9aacb1cd1d16b23de2f4485e1398640e97f8bd Mon Sep 17 00:00:00 2001 From: Franklyn Dsouza Date: Wed, 8 Jul 2020 11:38:56 -0400 Subject: [PATCH] Fix avg sql aggregator (#10135) * new average aggregator * method to create count aggregator factory * test everything * update other usages * fix style * fix more tests * fix datasketches tests --- .../hll/sql/HllSketchSqlAggregatorTest.java | 18 ++- .../sql/ThetaSketchSqlAggregatorTest.java | 18 ++- .../aggregation/builtin/AvgSqlAggregator.java | 65 +++++++- .../builtin/CountSqlAggregator.java | 65 +++++--- .../druid/sql/calcite/CalciteQueryTest.java | 153 ++++++++++++++---- 5 files changed, 247 insertions(+), 72 deletions(-) diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 6f4c8b6a9aa..039801c5efc 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -388,10 +388,20 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase ) .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), - new CountAggregatorFactory("_a0:count") - )) + .setAggregatorSpecs( + NullHandling.replaceWithDefault() + ? Arrays.asList( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new CountAggregatorFactory("_a0:count") + ) + : Arrays.asList( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("_a0:count"), + BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null)) + ) + ) + ) .setPostAggregatorSpecs( ImmutableList.of( new ArithmeticPostAggregator( diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java index 201380a1b20..9201c91b994 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java @@ -385,10 +385,20 @@ public class ThetaSketchSqlAggregatorTest extends CalciteTestBase ) .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), - new CountAggregatorFactory("_a0:count") - )) + .setAggregatorSpecs( + NullHandling.replaceWithDefault() + ? Arrays.asList( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new CountAggregatorFactory("_a0:count") + ) + : Arrays.asList( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("_a0:count"), + BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null)) + ) + ) + ) .setPostAggregatorSpecs( ImmutableList.of( new ArithmeticPostAggregator( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java index 2761d0c3e56..3b973443668 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java @@ -20,20 +20,31 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.aggregation.AggregatorFactory; -import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; +import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; +import org.apache.druid.sql.calcite.aggregation.Aggregations; +import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; -public class AvgSqlAggregator extends SimpleSqlAggregator +import javax.annotation.Nullable; +import java.util.List; + +public class AvgSqlAggregator implements SqlAggregator { @Override public SqlAggFunction calciteFunction() @@ -41,15 +52,46 @@ public class AvgSqlAggregator extends SimpleSqlAggregator return SqlStdOperatorTable.AVG; } + @Nullable @Override - Aggregation getAggregation( + public Aggregation toDruidAggregation( + final PlannerContext plannerContext, + final RowSignature rowSignature, + final VirtualColumnRegistry virtualColumnRegistry, + final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final ExprMacroTable macroTable, - final String fieldName, - final String expression + final Project project, + final List existingAggregations, + final boolean finalizeAggregations ) { + + final List arguments = Aggregations.getArgumentsForSimpleAggregator( + plannerContext, + rowSignature, + aggregateCall, + project + ); + + if (arguments == null) { + return null; + } + + final String fieldName; + final String expression; + final DruidExpression arg = Iterables.getOnlyElement(arguments); + + if (arg.isDirectColumnAccess()) { + fieldName = arg.getDirectColumn(); + expression = null; + } else { + fieldName = null; + expression = arg.getExpression(); + } + + final ExprMacroTable macroTable = plannerContext.getExprMacroTable(); + final ValueType sumType; // Use 64-bit sum regardless of the type of the AVG aggregator. if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) { @@ -67,8 +109,15 @@ public class AvgSqlAggregator extends SimpleSqlAggregator expression, macroTable ); - - final AggregatorFactory count = new CountAggregatorFactory(countName); + final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory( + countName, + plannerContext, + rowSignature, + virtualColumnRegistry, + rexBuilder, + aggregateCall, + project + ); return Aggregation.create( ImmutableList.of(sum, count), diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java index 6bf8b601be9..f6747986df9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java @@ -28,7 +28,9 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; @@ -52,6 +54,41 @@ public class CountSqlAggregator implements SqlAggregator return SqlStdOperatorTable.COUNT; } + static AggregatorFactory createCountAggregatorFactory( + final String countName, + final PlannerContext plannerContext, + final RowSignature rowSignature, + final VirtualColumnRegistry virtualColumnRegistry, + final RexBuilder rexBuilder, + final AggregateCall aggregateCall, + final Project project + ) + { + final RexNode rexNode = Expressions.fromFieldAccess( + rowSignature, + project, + Iterables.getOnlyElement(aggregateCall.getArgList()) + ); + + if (rexNode.getType().isNullable()) { + final DimFilter nonNullFilter = Expressions.toFilter( + plannerContext, + rowSignature, + virtualColumnRegistry, + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode)) + ); + + if (nonNullFilter == null) { + // Don't expect this to happen. + throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode); + } + + return new FilteredAggregatorFactory(new CountAggregatorFactory(countName), nonNullFilter); + } else { + return new CountAggregatorFactory(countName); + } + } + @Nullable @Override public Aggregation toDruidAggregation( @@ -96,32 +133,16 @@ public class CountSqlAggregator implements SqlAggregator } } else { // Not COUNT(*), not distinct - // COUNT(x) should count all non-null values of x. - final RexNode rexNode = Expressions.fromFieldAccess( - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); - - if (rexNode.getType().isNullable()) { - final DimFilter nonNullFilter = Expressions.toFilter( + return Aggregation.create(createCountAggregatorFactory( + name, plannerContext, rowSignature, virtualColumnRegistry, - rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode)) - ); - - if (nonNullFilter == null) { - // Don't expect this to happen. - throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode); - } - - return Aggregation.create(new CountAggregatorFactory(name)) - .filter(rowSignature, virtualColumnRegistry, nonNullFilter); - } else { - return Aggregation.create(new CountAggregatorFactory(name)); - } + rexBuilder, + aggregateCall, + project + )); } } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index e51e81d6454..deffe204d42 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -244,10 +244,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setInterval(querySegmentSpec(Filtration.eternity())) .setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING)) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new DoubleSumAggregatorFactory("a0:sum", "m2"), - new CountAggregatorFactory("a0:count") - ) + .setAggregatorSpecs( + useDefault + ? aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new CountAggregatorFactory("a0:count") + ) + : aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0:count"), + not(selector("m2", null, null)) + ) + ) ) .setPostAggregatorSpecs( ImmutableList.of( @@ -313,10 +322,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setInterval(querySegmentSpec(Filtration.eternity())) .setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING)) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new DoubleSumAggregatorFactory("a0:sum", "m2"), - new CountAggregatorFactory("a0:count") - ) + .setAggregatorSpecs( + useDefault + ? aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new CountAggregatorFactory("a0:count") + ) + : aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0:count"), + not(selector("m2", null, null)) + ) + ) ) .setPostAggregatorSpecs( ImmutableList.of( @@ -390,10 +408,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setInterval(querySegmentSpec(Filtration.eternity())) .setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING)) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new DoubleSumAggregatorFactory("a0:sum", "m2"), - new CountAggregatorFactory("a0:count") - ) + .setAggregatorSpecs( + useDefault + ? aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new CountAggregatorFactory("a0:count") + ) + : aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0:count"), + not(selector("m2", null, null)) + ) + ) ) .setPostAggregatorSpecs( ImmutableList.of( @@ -4730,11 +4757,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest cannotVectorize(); testQuery( - "SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2) FROM druid.foo", + "SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2), COUNT(d1), AVG(d1) FROM druid.numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() - .dataSource(CalciteTests.DATASOURCE1) + .dataSource(CalciteTests.DATASOURCE3) .intervals(querySegmentSpec(Filtration.eternity())) .granularity(Granularities.ALL) .aggregators( @@ -4753,7 +4780,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new FilteredAggregatorFactory( new CountAggregatorFactory("a6"), not(selector("dim2", null, null)) - ) + ), + new DoubleSumAggregatorFactory("a7:sum", "d1"), + new CountAggregatorFactory("a7:count") ) : aggregators( new CountAggregatorFactory("a0"), @@ -4766,13 +4795,25 @@ public class CalciteQueryTest extends BaseCalciteQueryTest not(selector("dim1", null, null)) ), new LongSumAggregatorFactory("a3:sum", "cnt"), - new CountAggregatorFactory("a3:count"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a3:count"), + not(selector("cnt", null, null)) + ), new LongSumAggregatorFactory("a4", "cnt"), new LongMinAggregatorFactory("a5", "cnt"), new LongMaxAggregatorFactory("a6", "cnt"), new FilteredAggregatorFactory( new CountAggregatorFactory("a7"), not(selector("dim2", null, null)) + ), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a8"), + not(selector("d1", null, null)) + ), + new DoubleSumAggregatorFactory("a9:sum", "d1"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a9:count"), + not(selector("d1", null, null)) ) ) ) @@ -4785,6 +4826,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new FieldAccessPostAggregator(null, useDefault ? "a2:count" : "a3:count") ) ), + new ArithmeticPostAggregator( + useDefault ? "a7" : "a9", + "quotient", + ImmutableList.of( + new FieldAccessPostAggregator(null, useDefault ? "a7:sum" : "a9:sum"), + new FieldAccessPostAggregator(null, useDefault ? "a7:count" : "a9:count") + ) + ), expressionPostAgg( "p0", useDefault ? "((\"a3\" + \"a4\") + \"a5\")" : "((\"a4\" + \"a5\") + \"a6\")" @@ -4795,10 +4844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ), NullHandling.replaceWithDefault() ? ImmutableList.of( - new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L} + new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)} ) : ImmutableList.of( - new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L} + new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)} ) ); } @@ -6801,14 +6850,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) .setInterval(querySegmentSpec(Filtration.eternity())) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new LongMaxAggregatorFactory("_a0", "a0"), - new LongMinAggregatorFactory("_a1", "a0"), - new LongSumAggregatorFactory("_a2:sum", "a0"), - new CountAggregatorFactory("_a2:count"), - new LongMaxAggregatorFactory("_a3", "d0"), - new CountAggregatorFactory("_a4") - )) + .setAggregatorSpecs( + useDefault + ? aggregators( + new LongMaxAggregatorFactory("_a0", "a0"), + new LongMinAggregatorFactory("_a1", "a0"), + new LongSumAggregatorFactory("_a2:sum", "a0"), + new CountAggregatorFactory("_a2:count"), + new LongMaxAggregatorFactory("_a3", "d0"), + new CountAggregatorFactory("_a4") + ) + : aggregators( + new LongMaxAggregatorFactory("_a0", "a0"), + new LongMinAggregatorFactory("_a1", "a0"), + new LongSumAggregatorFactory("_a2:sum", "a0"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("_a2:count"), + not(selector("a0", null, null)) + ), + new LongMaxAggregatorFactory("_a3", "d0"), + new CountAggregatorFactory("_a4") + ) + ) .setPostAggregatorSpecs( ImmutableList.of( new ArithmeticPostAggregator( @@ -6872,10 +6935,20 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ) .setInterval(querySegmentSpec(Filtration.eternity())) .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new LongSumAggregatorFactory("_a0:sum", "a0"), - new CountAggregatorFactory("_a0:count") - )) + .setAggregatorSpecs( + useDefault + ? aggregators( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new CountAggregatorFactory("_a0:count") + ) + : aggregators( + new LongSumAggregatorFactory("_a0:sum", "a0"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("_a0:count"), + not(selector("a0", null, null)) + ) + ) + ) .setPostAggregatorSpecs( ImmutableList.of( new ArithmeticPostAggregator( @@ -12935,10 +13008,22 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .dimension(new DefaultDimensionSpec("m1", "d0", ValueType.FLOAT)) .filters("dim2", "a") .aggregators( - new DoubleSumAggregatorFactory("a0:sum", "m2"), - new CountAggregatorFactory("a0:count"), - new DoubleSumAggregatorFactory("a1", "m1"), - new DoubleSumAggregatorFactory("a2", "m2") + useDefault + ? aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new CountAggregatorFactory("a0:count"), + new DoubleSumAggregatorFactory("a1", "m1"), + new DoubleSumAggregatorFactory("a2", "m2") + ) + : aggregators( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0:count"), + not(selector("m2", null, null)) + ), + new DoubleSumAggregatorFactory("a1", "m1"), + new DoubleSumAggregatorFactory("a2", "m2") + ) ) .postAggregators( new ArithmeticPostAggregator(