mirror of https://github.com/apache/druid.git
Fix avg sql aggregator (#10135)
* new average aggregator
* method to create count aggregator factory
* test everything
* update other usages
* fix style
* fix more tests
* fix datasketches tests
This commit is contained in:
parent
c776e412e0
commit
1b9aacb1cd
|
@ -388,10 +388,20 @@ public class HllSketchSqlAggregatorTest extends CalciteTestBase
|
|||
)
|
||||
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
))
|
||||
.setAggregatorSpecs(
|
||||
NullHandling.replaceWithDefault()
|
||||
? Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
)
|
||||
: Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a0:count"),
|
||||
BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
new ArithmeticPostAggregator(
|
||||
|
|
|
@ -385,10 +385,20 @@ public class ThetaSketchSqlAggregatorTest extends CalciteTestBase
|
|||
)
|
||||
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
))
|
||||
.setAggregatorSpecs(
|
||||
NullHandling.replaceWithDefault()
|
||||
? Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
)
|
||||
: Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a0:count"),
|
||||
BaseCalciteQueryTest.not(BaseCalciteQueryTest.selector("a0", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
new ArithmeticPostAggregator(
|
||||
|
|
|
@ -20,20 +20,31 @@
|
|||
package org.apache.druid.sql.calcite.aggregation.builtin;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Iterables;
|
||||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.core.Project;
|
||||
import org.apache.calcite.rex.RexBuilder;
|
||||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
|
||||
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregations;
|
||||
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import org.apache.druid.sql.calcite.expression.DruidExpression;
|
||||
import org.apache.druid.sql.calcite.planner.Calcites;
|
||||
import org.apache.druid.sql.calcite.planner.PlannerContext;
|
||||
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
|
||||
|
||||
public class AvgSqlAggregator extends SimpleSqlAggregator
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
public class AvgSqlAggregator implements SqlAggregator
|
||||
{
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
|
@ -41,15 +52,46 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
|
|||
return SqlStdOperatorTable.AVG;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
Aggregation getAggregation(
|
||||
public Aggregation toDruidAggregation(
|
||||
final PlannerContext plannerContext,
|
||||
final RowSignature rowSignature,
|
||||
final VirtualColumnRegistry virtualColumnRegistry,
|
||||
final RexBuilder rexBuilder,
|
||||
final String name,
|
||||
final AggregateCall aggregateCall,
|
||||
final ExprMacroTable macroTable,
|
||||
final String fieldName,
|
||||
final String expression
|
||||
final Project project,
|
||||
final List<Aggregation> existingAggregations,
|
||||
final boolean finalizeAggregations
|
||||
)
|
||||
{
|
||||
|
||||
final List<DruidExpression> arguments = Aggregations.getArgumentsForSimpleAggregator(
|
||||
plannerContext,
|
||||
rowSignature,
|
||||
aggregateCall,
|
||||
project
|
||||
);
|
||||
|
||||
if (arguments == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
final String fieldName;
|
||||
final String expression;
|
||||
final DruidExpression arg = Iterables.getOnlyElement(arguments);
|
||||
|
||||
if (arg.isDirectColumnAccess()) {
|
||||
fieldName = arg.getDirectColumn();
|
||||
expression = null;
|
||||
} else {
|
||||
fieldName = null;
|
||||
expression = arg.getExpression();
|
||||
}
|
||||
|
||||
final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
|
||||
|
||||
final ValueType sumType;
|
||||
// Use 64-bit sum regardless of the type of the AVG aggregator.
|
||||
if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
|
||||
|
@ -67,8 +109,15 @@ public class AvgSqlAggregator extends SimpleSqlAggregator
|
|||
expression,
|
||||
macroTable
|
||||
);
|
||||
|
||||
final AggregatorFactory count = new CountAggregatorFactory(countName);
|
||||
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
|
||||
countName,
|
||||
plannerContext,
|
||||
rowSignature,
|
||||
virtualColumnRegistry,
|
||||
rexBuilder,
|
||||
aggregateCall,
|
||||
project
|
||||
);
|
||||
|
||||
return Aggregation.create(
|
||||
ImmutableList.of(sum, count),
|
||||
|
|
|
@ -28,7 +28,9 @@ import org.apache.calcite.rex.RexNode;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.filter.DimFilter;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.sql.calcite.aggregation.Aggregation;
|
||||
|
@ -52,6 +54,41 @@ public class CountSqlAggregator implements SqlAggregator
|
|||
return SqlStdOperatorTable.COUNT;
|
||||
}
|
||||
|
||||
/**
 * Builds the count aggregator factory for a single-argument aggregate call, so that only
 * non-null values of the argument are counted when the argument can be null.
 *
 * The argument is resolved to a RexNode through Expressions.fromFieldAccess. If that node's type
 * is nullable, the returned factory is a FilteredAggregatorFactory wrapping a
 * CountAggregatorFactory with an IS NOT NULL filter on the argument; otherwise it is a plain
 * CountAggregatorFactory.
 *
 * @param countName             output name given to the CountAggregatorFactory
 * @param plannerContext        passed through to Expressions.toFilter
 * @param rowSignature          used both to resolve the argument and to build the filter
 * @param virtualColumnRegistry passed through to Expressions.toFilter
 * @param rexBuilder            used to build the IS NOT NULL call on the argument
 * @param aggregateCall         the aggregate call; its argument list is read with
 *                              Iterables.getOnlyElement, so it is expected to hold exactly one entry
 * @param project               passed to Expressions.fromFieldAccess when resolving the argument
 *
 * @return a filtered count when the argument is nullable, an unfiltered count otherwise
 *
 * @throws ISE if the IS NOT NULL expression cannot be converted into a filter
 */
static AggregatorFactory createCountAggregatorFactory(
|
||||
final String countName,
|
||||
final PlannerContext plannerContext,
|
||||
final RowSignature rowSignature,
|
||||
final VirtualColumnRegistry virtualColumnRegistry,
|
||||
final RexBuilder rexBuilder,
|
||||
final AggregateCall aggregateCall,
|
||||
final Project project
|
||||
)
|
||||
{
|
||||
// Resolve the call's only argument to the expression it refers to.
final RexNode rexNode = Expressions.fromFieldAccess(
|
||||
rowSignature,
|
||||
project,
|
||||
Iterables.getOnlyElement(aggregateCall.getArgList())
|
||||
);
|
||||
|
||||
// Only a nullable argument needs the not-null filter; a non-nullable one is counted as-is.
if (rexNode.getType().isNullable()) {
|
||||
final DimFilter nonNullFilter = Expressions.toFilter(
|
||||
plannerContext,
|
||||
rowSignature,
|
||||
virtualColumnRegistry,
|
||||
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode))
|
||||
);
|
||||
|
||||
if (nonNullFilter == null) {
|
||||
// Don't expect this to happen.
|
||||
throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode);
|
||||
}
|
||||
|
||||
// Count only the rows that pass the not-null filter.
return new FilteredAggregatorFactory(new CountAggregatorFactory(countName), nonNullFilter);
|
||||
} else {
|
||||
return new CountAggregatorFactory(countName);
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Aggregation toDruidAggregation(
|
||||
|
@ -96,32 +133,16 @@ public class CountSqlAggregator implements SqlAggregator
|
|||
}
|
||||
} else {
|
||||
// Not COUNT(*), not distinct
|
||||
|
||||
// COUNT(x) should count all non-null values of x.
|
||||
final RexNode rexNode = Expressions.fromFieldAccess(
|
||||
rowSignature,
|
||||
project,
|
||||
Iterables.getOnlyElement(aggregateCall.getArgList())
|
||||
);
|
||||
|
||||
if (rexNode.getType().isNullable()) {
|
||||
final DimFilter nonNullFilter = Expressions.toFilter(
|
||||
return Aggregation.create(createCountAggregatorFactory(
|
||||
name,
|
||||
plannerContext,
|
||||
rowSignature,
|
||||
virtualColumnRegistry,
|
||||
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(rexNode))
|
||||
);
|
||||
|
||||
if (nonNullFilter == null) {
|
||||
// Don't expect this to happen.
|
||||
throw new ISE("Could not create not-null filter for rexNode[%s]", rexNode);
|
||||
}
|
||||
|
||||
return Aggregation.create(new CountAggregatorFactory(name))
|
||||
.filter(rowSignature, virtualColumnRegistry, nonNullFilter);
|
||||
} else {
|
||||
return Aggregation.create(new CountAggregatorFactory(name));
|
||||
}
|
||||
rexBuilder,
|
||||
aggregateCall,
|
||||
project
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -244,10 +244,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count")
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
useDefault
|
||||
? aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count")
|
||||
)
|
||||
: aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0:count"),
|
||||
not(selector("m2", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
|
@ -313,10 +322,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count")
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
useDefault
|
||||
? aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count")
|
||||
)
|
||||
: aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0:count"),
|
||||
not(selector("m2", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
|
@ -390,10 +408,19 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimensions(new DefaultDimensionSpec("dim2", "d0", ValueType.STRING))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count")
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
useDefault
|
||||
? aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count")
|
||||
)
|
||||
: aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0:count"),
|
||||
not(selector("m2", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
|
@ -4730,11 +4757,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
cannotVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2) FROM druid.foo",
|
||||
"SELECT COUNT(*), COUNT(cnt), COUNT(dim1), AVG(cnt), SUM(cnt), SUM(cnt) + MIN(cnt) + MAX(cnt), COUNT(dim2), COUNT(d1), AVG(d1) FROM druid.numfoo",
|
||||
ImmutableList.of(
|
||||
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.dataSource(CalciteTests.DATASOURCE3)
|
||||
.intervals(querySegmentSpec(Filtration.eternity()))
|
||||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
|
@ -4753,7 +4780,9 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a6"),
|
||||
not(selector("dim2", null, null))
|
||||
)
|
||||
),
|
||||
new DoubleSumAggregatorFactory("a7:sum", "d1"),
|
||||
new CountAggregatorFactory("a7:count")
|
||||
)
|
||||
: aggregators(
|
||||
new CountAggregatorFactory("a0"),
|
||||
|
@ -4766,13 +4795,25 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
not(selector("dim1", null, null))
|
||||
),
|
||||
new LongSumAggregatorFactory("a3:sum", "cnt"),
|
||||
new CountAggregatorFactory("a3:count"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a3:count"),
|
||||
not(selector("cnt", null, null))
|
||||
),
|
||||
new LongSumAggregatorFactory("a4", "cnt"),
|
||||
new LongMinAggregatorFactory("a5", "cnt"),
|
||||
new LongMaxAggregatorFactory("a6", "cnt"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a7"),
|
||||
not(selector("dim2", null, null))
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a8"),
|
||||
not(selector("d1", null, null))
|
||||
),
|
||||
new DoubleSumAggregatorFactory("a9:sum", "d1"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a9:count"),
|
||||
not(selector("d1", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -4785,6 +4826,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new FieldAccessPostAggregator(null, useDefault ? "a2:count" : "a3:count")
|
||||
)
|
||||
),
|
||||
new ArithmeticPostAggregator(
|
||||
useDefault ? "a7" : "a9",
|
||||
"quotient",
|
||||
ImmutableList.of(
|
||||
new FieldAccessPostAggregator(null, useDefault ? "a7:sum" : "a9:sum"),
|
||||
new FieldAccessPostAggregator(null, useDefault ? "a7:count" : "a9:count")
|
||||
)
|
||||
),
|
||||
expressionPostAgg(
|
||||
"p0",
|
||||
useDefault ? "((\"a3\" + \"a4\") + \"a5\")" : "((\"a4\" + \"a5\") + \"a6\")"
|
||||
|
@ -4795,10 +4844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
),
|
||||
NullHandling.replaceWithDefault() ?
|
||||
ImmutableList.of(
|
||||
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L}
|
||||
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
|
||||
) :
|
||||
ImmutableList.of(
|
||||
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L}
|
||||
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -6801,14 +6850,28 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new LongMaxAggregatorFactory("_a0", "a0"),
|
||||
new LongMinAggregatorFactory("_a1", "a0"),
|
||||
new LongSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new CountAggregatorFactory("_a2:count"),
|
||||
new LongMaxAggregatorFactory("_a3", "d0"),
|
||||
new CountAggregatorFactory("_a4")
|
||||
))
|
||||
.setAggregatorSpecs(
|
||||
useDefault
|
||||
? aggregators(
|
||||
new LongMaxAggregatorFactory("_a0", "a0"),
|
||||
new LongMinAggregatorFactory("_a1", "a0"),
|
||||
new LongSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new CountAggregatorFactory("_a2:count"),
|
||||
new LongMaxAggregatorFactory("_a3", "d0"),
|
||||
new CountAggregatorFactory("_a4")
|
||||
)
|
||||
: aggregators(
|
||||
new LongMaxAggregatorFactory("_a0", "a0"),
|
||||
new LongMinAggregatorFactory("_a1", "a0"),
|
||||
new LongSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a2:count"),
|
||||
not(selector("a0", null, null))
|
||||
),
|
||||
new LongMaxAggregatorFactory("_a3", "d0"),
|
||||
new CountAggregatorFactory("_a4")
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
new ArithmeticPostAggregator(
|
||||
|
@ -6872,10 +6935,20 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
))
|
||||
.setAggregatorSpecs(
|
||||
useDefault
|
||||
? aggregators(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
)
|
||||
: aggregators(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a0:count"),
|
||||
not(selector("a0", null, null))
|
||||
)
|
||||
)
|
||||
)
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
new ArithmeticPostAggregator(
|
||||
|
@ -12935,10 +13008,22 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.dimension(new DefaultDimensionSpec("m1", "d0", ValueType.FLOAT))
|
||||
.filters("dim2", "a")
|
||||
.aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count"),
|
||||
new DoubleSumAggregatorFactory("a1", "m1"),
|
||||
new DoubleSumAggregatorFactory("a2", "m2")
|
||||
useDefault
|
||||
? aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new CountAggregatorFactory("a0:count"),
|
||||
new DoubleSumAggregatorFactory("a1", "m1"),
|
||||
new DoubleSumAggregatorFactory("a2", "m2")
|
||||
)
|
||||
: aggregators(
|
||||
new DoubleSumAggregatorFactory("a0:sum", "m2"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a0:count"),
|
||||
not(selector("m2", null, null))
|
||||
),
|
||||
new DoubleSumAggregatorFactory("a1", "m1"),
|
||||
new DoubleSumAggregatorFactory("a2", "m2")
|
||||
)
|
||||
)
|
||||
.postAggregators(
|
||||
new ArithmeticPostAggregator(
|
||||
|
|
Loading…
Reference in New Issue