Fix avg sql aggregator (#10135)

* new average aggregator

* method to create count aggregator factory

* test everything

* update other usages

* fix style

* fix more tests

* fix datasketches tests
This commit is contained in:
Franklyn Dsouza 2020-07-08 11:38:56 -04:00 committed by GitHub
parent c776e412e0
commit 1b9aacb1cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 247 additions and 72 deletions

View File

@ -388,10 +388,20 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase
) )
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(Arrays.asList( .setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"), new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count") new CountAggregatorFactory("_a0:count")
)) )
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
)
)
)
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
new ArithmeticPostAggregator( new ArithmeticPostAggregator(

View File

@ -385,10 +385,20 @@ public class ThetaSketchSqlAggregatorTest extends CalciteTestBase
) )
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) .setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(Arrays.asList( .setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"), new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count") new CountAggregatorFactory("_a0:count")
)) )
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
)
)
)
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
new ArithmeticPostAggregator( new ArithmeticPostAggregator(

View File

@ -20,20 +20,31 @@
package org.apache.druid.sql.calcite.aggregation.builtin; package org.apache.druid.sql.calcite.aggregation.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
public class AvgSqlAggregator extends SimpleSqlAggregator import javax.annotation.Nullable;
import java.util.List;
public class AvgSqlAggregator implements SqlAggregator
{ {
@Override @Override
public SqlAggFunction calciteFunction() public SqlAggFunction calciteFunction()
@ -41,15 +52,46 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
return SqlStdOperatorTable.AVG; return SqlStdOperatorTable.AVG;
} }
@Nullable
@Override @Override
Aggregation getAggregation( public Aggregation toDruidAggregation(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final VirtualColumnRegistry virtualColumnRegistry,
final RexBuilder rexBuilder,
final String name, final String name,
final AggregateCall aggregateCall, final AggregateCall aggregateCall,
final ExprMacroTable macroTable, final Project project,
final String fieldName, final List<Aggregation> existingAggregations,
final String expression final boolean finalizeAggregations
) )
{ {
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
plannerContext,
rowSignature,
aggregateCall,
project
);
if (arguments == null) {
return null;
}
final String fieldName;
final String expression;
final DruidExpression arg = Iterables.getOnlyElement(arguments);
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
expression = null;
} else {
fieldName = null;
expression = arg.getExpression();
}
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
final ValueType sumType; final ValueType sumType;
// Use 64-bit sum regardless of the type of the AVG aggregator. // Use 64-bit sum regardless of the type of the AVG aggregator.
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) { if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
@ -67,8 +109,15 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
expression, expression,
macroTable macroTable
); );
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
final AggregatorFactory count = new CountAggregatorFactory(countName); countName,
plannerContext,
rowSignature,
virtualColumnRegistry,
rexBuilder,
aggregateCall,
project
);
return Aggregation.create( return Aggregation.create(
ImmutableList.of(sum, count), ImmutableList.of(sum, count),

View File

@ -28,7 +28,9 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregation;
@ -52,6 +54,41 @@ public class CountSqlAggregator implements SqlAggregator
return SqlStdOperatorTable.COUNT; return SqlStdOperatorTable.COUNT;
} }
/**
 * Creates the count aggregator for an aggregate call that has exactly one argument.
 *
 * <p>The argument is resolved to a {@link RexNode} through {@code Expressions.fromFieldAccess}. If the
 * type of that node is not nullable, a plain {@link CountAggregatorFactory} is returned. Otherwise the
 * count is wrapped in a {@link FilteredAggregatorFactory} whose filter is built from an
 * {@code IS NOT NULL} call on the argument, so that only non-null values are counted.
 *
 * @param countName             output name given to the count aggregator
 * @param plannerContext        planner context passed to the filter translation
 * @param rowSignature          row signature used to resolve the argument and to build the filter
 * @param virtualColumnRegistry virtual column registry passed to the filter translation
 * @param rexBuilder            builder used to create the {@code IS NOT NULL} call
 * @param aggregateCall         aggregate call whose single argument is counted
 * @param project               projection passed to the field-access resolution
 *
 * @return a count aggregator, filtered on non-null values when the argument type is nullable
 *
 * @throws ISE if the {@code IS NOT NULL} expression cannot be translated into a filter
 */
static AggregatorFactory createCountAggregatorFactory(
    final String countName,
    final PlannerContext plannerContext,
    final RowSignature rowSignature,
    final VirtualColumnRegistry virtualColumnRegistry,
    final RexBuilder rexBuilder,
    final AggregateCall aggregateCall,
    final Project project
)
{
  final RexNode argument = Expressions.fromFieldAccess(
      rowSignature,
      project,
      Iterables.getOnlyElement(aggregateCall.getArgList())
  );

  // A non-nullable argument needs no filtering: a plain count is enough.
  if (!argument.getType().isNullable()) {
    return new CountAggregatorFactory(countName);
  }

  // Nullable argument: count only the rows where the argument IS NOT NULL.
  final RexNode isNotNullCall = rexBuilder.makeCall(
      SqlStdOperatorTable.IS_NOT_NULL,
      ImmutableList.of(argument)
  );
  final DimFilter notNullFilter = Expressions.toFilter(
      plannerContext,
      rowSignature,
      virtualColumnRegistry,
      isNotNullCall
  );

  if (notNullFilter == null) {
    // Don't expect this to happen.
    throw new ISE("Could not create not-null filter for rexNode[%s]", argument);
  }

  return new FilteredAggregatorFactory(new CountAggregatorFactory(countName), notNullFilter);
}
@Nullable @Nullable
@Override @Override
public Aggregation toDruidAggregation( public Aggregation toDruidAggregation(
@ -96,32 +133,16 @@ public class CountSqlAggregator implements SqlAggregator
} }
} else { } else {
// Not COUNT(*), not distinct // Not COUNT(*), not distinct
// COUNT(x) should count all non-null values of x. // COUNT(x) should count all non-null values of x.
final RexNode rexNode = Expressions.fromFieldAccess( return Aggregation.create(createCountAggregatorFactory(
rowSignature, name,
project,
Iterables.getOnlyElement(aggregateCall.getArgList())
);
if (rexNode.getType().isNullable()) {
final DimFilter nonNullFilter = Expressions.toFilter(
plannerContext, plannerContext,
rowSignature, rowSignature,
virtualColumnRegistry, virtualColumnRegistry,
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode)) rexBuilder,
); aggregateCall,
project
if (nonNullFilter == null) { ));
// Don't expect this to happen.
throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode);
}
return Aggregation.create(new CountAggregatorFactory(name))
.filter(rowSignature, virtualColumnRegistry, nonNullFilter);
} else {
return Aggregation.create(new CountAggregatorFactory(name));
}
} }
} }
} }

View File

@ -244,10 +244,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity())) .setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING)) .setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"), new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count") new CountAggregatorFactory("a0:count")
) )
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
)
)
) )
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
@ -313,10 +322,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity())) .setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING)) .setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"), new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count") new CountAggregatorFactory("a0:count")
) )
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
)
)
) )
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
@ -390,10 +408,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setInterval(querySegmentSpec(Filtration.eternity())) .setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING)) .setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"), new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count") new CountAggregatorFactory("a0:count")
) )
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
)
)
) )
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
@ -4730,11 +4757,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
cannotVectorize(); cannotVectorize();
testQuery( testQuery(
"SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2) FROM druid.foo", "SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2), COUNT(d1), AVG(d1) FROM druid.numfoo",
ImmutableList.of( ImmutableList.of(
Druids.newTimeseriesQueryBuilder() Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1) .dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity())) .intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL) .granularity(Granularities.ALL)
.aggregators( .aggregators(
@ -4753,7 +4780,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new CountAggregatorFactory("a6"), new CountAggregatorFactory("a6"),
not(selector("dim2", null, null)) not(selector("dim2", null, null))
) ),
new DoubleSumAggregatorFactory("a7:sum", "d1"),
new CountAggregatorFactory("a7:count")
) )
: aggregators( : aggregators(
new CountAggregatorFactory("a0"), new CountAggregatorFactory("a0"),
@ -4766,13 +4795,25 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
not(selector("dim1", null, null)) not(selector("dim1", null, null))
), ),
new LongSumAggregatorFactory("a3:sum", "cnt"), new LongSumAggregatorFactory("a3:sum", "cnt"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a3:count"), new CountAggregatorFactory("a3:count"),
not(selector("cnt", null, null))
),
new LongSumAggregatorFactory("a4", "cnt"), new LongSumAggregatorFactory("a4", "cnt"),
new LongMinAggregatorFactory("a5", "cnt"), new LongMinAggregatorFactory("a5", "cnt"),
new LongMaxAggregatorFactory("a6", "cnt"), new LongMaxAggregatorFactory("a6", "cnt"),
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new CountAggregatorFactory("a7"), new CountAggregatorFactory("a7"),
not(selector("dim2", null, null)) not(selector("dim2", null, null))
),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a8"),
not(selector("d1", null, null))
),
new DoubleSumAggregatorFactory("a9:sum", "d1"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a9:count"),
not(selector("d1", null, null))
) )
) )
) )
@ -4785,6 +4826,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new FieldAccessPostAggregator(null, useDefault ? "a2:count" : "a3:count") new FieldAccessPostAggregator(null, useDefault ? "a2:count" : "a3:count")
) )
), ),
new ArithmeticPostAggregator(
useDefault ? "a7" : "a9",
"quotient",
ImmutableList.of(
new FieldAccessPostAggregator(null, useDefault ? "a7:sum" : "a9:sum"),
new FieldAccessPostAggregator(null, useDefault ? "a7:count" : "a9:count")
)
),
expressionPostAgg( expressionPostAgg(
"p0", "p0",
useDefault ? "((\"a3\" + \"a4\") + \"a5\")" : "((\"a4\" + \"a5\") + \"a6\")" useDefault ? "((\"a3\" + \"a4\") + \"a5\")" : "((\"a4\" + \"a5\") + \"a6\")"
@ -4795,10 +4844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
), ),
NullHandling.replaceWithDefault() ? NullHandling.replaceWithDefault() ?
ImmutableList.of( ImmutableList.of(
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L} new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
) : ) :
ImmutableList.of( ImmutableList.of(
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L} new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
) )
); );
} }
@ -6801,14 +6850,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
) )
.setInterval(querySegmentSpec(Filtration.eternity())) .setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(
useDefault
? aggregators(
new LongMaxAggregatorFactory("_a0", "a0"), new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"), new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"), new LongSumAggregatorFactory("_a2:sum", "a0"),
new CountAggregatorFactory("_a2:count"), new CountAggregatorFactory("_a2:count"),
new LongMaxAggregatorFactory("_a3", "d0"), new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4") new CountAggregatorFactory("_a4")
)) )
: aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a2:count"),
not(selector("a0", null, null))
),
new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4")
)
)
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
new ArithmeticPostAggregator( new ArithmeticPostAggregator(
@ -6872,10 +6935,20 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
) )
.setInterval(querySegmentSpec(Filtration.eternity())) .setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(
useDefault
? aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"), new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count") new CountAggregatorFactory("_a0:count")
)) )
: aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
not(selector("a0", null, null))
)
)
)
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
new ArithmeticPostAggregator( new ArithmeticPostAggregator(
@ -12935,11 +13008,23 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.dimension(new DefaultDimensionSpec("m1", "d0", ValueType.FLOAT)) .dimension(new DefaultDimensionSpec("m1", "d0", ValueType.FLOAT))
.filters("dim2", "a") .filters("dim2", "a")
.aggregators( .aggregators(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"), new DoubleSumAggregatorFactory("a0:sum", "m2"),
new CountAggregatorFactory("a0:count"), new CountAggregatorFactory("a0:count"),
new DoubleSumAggregatorFactory("a1", "m1"), new DoubleSumAggregatorFactory("a1", "m1"),
new DoubleSumAggregatorFactory("a2", "m2") new DoubleSumAggregatorFactory("a2", "m2")
) )
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("m2", null, null))
),
new DoubleSumAggregatorFactory("a1", "m1"),
new DoubleSumAggregatorFactory("a2", "m2")
)
)
.postAggregators( .postAggregators(
new ArithmeticPostAggregator( new ArithmeticPostAggregator(
"a0", "a0",