diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 349f1a57d1c..20ea97aec3e 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -369,7 +369,7 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest final List expectedResults = ImmutableList.of( new Object[]{ - 1L + 1.0 } ); @@ -429,11 +429,11 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest .setAggregatorSpecs( NullHandling.replaceWithDefault() ? Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new CountAggregatorFactory("_a0:count") ) : Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a0:count"), notNull("a0") diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java index 3946ce558b1..3a079e06478 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java @@ -278,7 +278,7 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest final List expectedResults = ImmutableList.of( new Object[]{ - 1L + 1.0 } ); @@ -334,11 +334,11 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest .setAggregatorSpecs( NullHandling.replaceWithDefault() ? Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new CountAggregatorFactory("_a0:count") ) : Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a0:count"), notNull("a0") diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index 0b1562eb83d..ee8c469c3b8 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -30,6 +30,7 @@ import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator private static final String STDDEV_NAME = "STDDEV"; private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(VARIANCE_NAME); + buildSqlVarianceAggFunction(VARIANCE_NAME); private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.VAR_POP.name()); + buildSqlVarianceAggFunction(SqlKind.VAR_POP.name()); private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name()); + buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name()); private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(STDDEV_NAME); + buildSqlVarianceAggFunction(STDDEV_NAME); private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name()); + buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name()); private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name()); + buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name()); @Nullable @Override @@ -160,14 +161,15 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator } /** - * Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction} - * but with an operand type that accepts variance aggregator objects in addition to numeric inputs. + * Creates a {@link SqlAggFunction} + * + * It accepts variance aggregator objects in addition to numeric inputs. */ - private static SqlAggFunction buildSqlAvgAggFunction(String name) + private static SqlAggFunction buildSqlVarianceAggFunction(String name) { return OperatorConversions .aggregatorBuilder(name) - .returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION) + .returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE)) .operandTypeChecker( OperandTypes.or( OperandTypes.NUMERIC, diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java index fe68b2737ef..e45a9378496 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -171,8 +171,8 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest final List expectedResults = ImmutableList.of( new Object[]{ holder1.getVariance(true), - holder2.getVariance(true).doubleValue(), - holder3.getVariance(true).longValue() + holder2.getVariance(true), + holder3.getVariance(true) } ); testQuery( @@ -219,7 +219,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest new Object[] { holder1.getVariance(false), holder2.getVariance(false).doubleValue(), - holder3.getVariance(false).longValue(), + holder3.getVariance(false), } ); testQuery( @@ -266,7 +266,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest new Object[] { Math.sqrt(holder1.getVariance(true)), Math.sqrt(holder2.getVariance(true)), - (long) Math.sqrt(holder3.getVariance(true)), + Math.sqrt(holder3.getVariance(true)), } ); @@ -321,7 +321,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest new Object[]{ Math.sqrt(holder1.getVariance(false)), Math.sqrt(holder2.getVariance(false)), - (long) Math.sqrt(holder3.getVariance(false)), + Math.sqrt(holder3.getVariance(false)), } ); @@ -374,7 +374,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest new Object[]{ Math.sqrt(holder1.getVariance(false)), Math.sqrt(holder2.getVariance(false)), - (long) Math.sqrt(holder3.getVariance(false)), + Math.sqrt(holder3.getVariance(false)), } ); @@ -543,7 +543,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest ), ImmutableList.of( NullHandling.replaceWithDefault() - ? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L} + ? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} : new Object[]{null, null, null, null, null, null, null, null} ) ); @@ -623,7 +623,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest ), ImmutableList.of( NullHandling.replaceWithDefault() - ? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L} + ? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} : new Object[]{"a", null, null, null, null, null, null, null, null} ) ); @@ -688,9 +688,9 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest Assert.assertEquals(expectedResult.length, result.length); for (int j = 0; j < expectedResult.length; j++) { if (expectedResult[j] instanceof Float) { - Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10); + Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5); } else if (expectedResult[j] instanceof Double) { - Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10); + Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5); } else { Assert.assertEquals(expectedResult[j], result[j]); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java index d3d09f7bdf3..dcba20ee6c4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java @@ -124,13 +124,7 @@ public class DruidTypeSystem implements RelDataTypeSystem final RelDataType argumentType ) { - // Widen all averages to 64-bits regardless of the size of the inputs. - - if (SqlTypeName.INT_TYPES.contains(argumentType.getSqlTypeName())) { - return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.BIGINT, argumentType.isNullable()); - } else { - return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable()); - } + return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable()); } @Override diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java index 89b09872d40..a7a5222d888 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java @@ -29,9 +29,10 @@ import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; -import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory; import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory; import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory; import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator; @@ -127,7 +128,7 @@ public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) .setDimensions(new DefaultDimensionSpec("d1", "_d0")) .setAggregatorSpecs( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), useDefault ? new CountAggregatorFactory("_a0:count") : new FilteredAggregatorFactory( @@ -158,15 +159,15 @@ public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest ) .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) .setDimensions(new DefaultDimensionSpec("country", "d0")) - .setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0")) + .setAggregatorSpecs(new DoubleAnyAggregatorFactory("a0", "j0._a0")) .setGranularity(new AllGranularity()) .setContext(queryContext) .build() ), ImmutableList.of( - new Object[]{"India", 2L}, - new Object[]{"USA", 1L}, - new Object[]{"canada", 3L} + new Object[]{"India", 2.0}, + new Object[]{"USA", 1.0}, + new Object[]{"canada", 3.0} ) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java index 72687eb3a19..a1438824b40 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java @@ -221,7 +221,7 @@ public class CalciteParameterQueryTest extends BaseCalciteQueryTest + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", ImmutableList.of(), ImmutableList.of( - new Object[]{8L, 1249L, 156L, -5L, 1111L} + new Object[]{8L, 1249L, 156.125, -5L, 1111L} ), ImmutableList.of( new SqlParameter(SqlType.VARCHAR, "druid"), diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index da3e3f21b09..4042def2750 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -374,7 +374,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest + "WHERE TABLE_SCHEMA = 'druid' AND TABLE_NAME = 'foo'", ImmutableList.of(), ImmutableList.of( - new Object[]{8L, 1249L, 156L, -5L, 1111L} + new Object[]{8L, 1249L, 156.125, -5L, 1111L} ) ); } @@ -4942,7 +4942,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new CountAggregatorFactory("a1"), notNull("dim1") ), - new LongSumAggregatorFactory("a2:sum", "cnt"), + new DoubleSumAggregatorFactory("a2:sum", "cnt"), new CountAggregatorFactory("a2:count"), new LongSumAggregatorFactory("a3", "cnt"), new LongMinAggregatorFactory("a4", "cnt"), @@ -4964,7 +4964,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new CountAggregatorFactory("a2"), notNull("dim1") ), - new LongSumAggregatorFactory("a3:sum", "cnt"), + new DoubleSumAggregatorFactory("a3:sum", "cnt"), new FilteredAggregatorFactory( new CountAggregatorFactory("a3:count"), notNull("cnt") @@ -5014,10 +5014,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ), NullHandling.replaceWithDefault() ? ImmutableList.of( - new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)} + new Object[]{6L, 6L, 5L, 1.0, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)} ) : ImmutableList.of( - new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)} + new Object[]{6L, 6L, 6L, 1.0, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)} ) ); } @@ -7429,11 +7429,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setAggregatorSpecs( useDefault ? aggregators( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new CountAggregatorFactory("_a0:count") ) : aggregators( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a0:count"), notNull("a0") @@ -7455,7 +7455,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setContext(QUERY_CONTEXT_DEFAULT) .build() ), - ImmutableList.of(new Object[]{1L}) + ImmutableList.of(new Object[]{1.0}) ); } @@ -9641,7 +9641,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new LongSumAggregatorFactory("a6", "l1"), new LongMaxAggregatorFactory("a7", "l1"), new LongMinAggregatorFactory("a8", "l1"), - new LongSumAggregatorFactory("a9:sum", "l1"), + new DoubleSumAggregatorFactory("a9:sum", "l1"), useDefault ? new CountAggregatorFactory("a9:count") : new FilteredAggregatorFactory( @@ -9690,7 +9690,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest 0L, Long.MIN_VALUE, Long.MAX_VALUE, - 0L, + Double.NaN, Double.NaN } : new Object[]{0L, 0L, 0L, null, null, null, null, null, null, null, null} @@ -9936,7 +9936,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest equality("dim1", "nonexistent", ColumnType.STRING) ), new FilteredAggregatorFactory( - new LongSumAggregatorFactory("a9:sum", "l1"), + new DoubleSumAggregatorFactory("a9:sum", "l1"), equality("dim1", "nonexistent", ColumnType.STRING) ), useDefault @@ -10005,7 +10005,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest 0L, Long.MIN_VALUE, Long.MAX_VALUE, - 0L, + Double.NaN, Double.NaN } : new Object[]{"a", 0L, 0L, 0L, null, null, null, null, null, null, null, null} @@ -13147,7 +13147,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new CountAggregatorFactory("a0"), notNull("v0") ), - new LongSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE), + new DoubleSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE), new CountAggregatorFactory("a1:count") ); virtualColumns = ImmutableList.of( @@ -13160,7 +13160,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new CountAggregatorFactory("a0"), notNull("v0") ), - new LongSumAggregatorFactory("a1:sum", "v1"), + new DoubleSumAggregatorFactory("a1:sum", "v1"), new FilteredAggregatorFactory( new CountAggregatorFactory("a1:count"), notNull("v1") @@ -13204,7 +13204,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .build() ), ImmutableList.of( - new Object[]{"ab", 1L, 325323L} + new Object[]{"ab", 1L, 325323.0} ) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java index 2ddc674eadd..fb4c61b8cec 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java @@ -34,6 +34,7 @@ import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; import org.apache.druid.query.aggregation.LongMinAggregatorFactory; @@ -558,14 +559,14 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest aggregators( new LongMaxAggregatorFactory("_a0", "a0"), new LongMinAggregatorFactory("_a1", "a0"), - new LongSumAggregatorFactory("_a2:sum", "a0"), + new DoubleSumAggregatorFactory("_a2:sum", "a0"), new CountAggregatorFactory("_a2:count"), new LongMaxAggregatorFactory("_a3", "d0"), new CountAggregatorFactory("_a4") ) : aggregators( new LongMaxAggregatorFactory("_a0", "a0"), new LongMinAggregatorFactory("_a1", "a0"), - new LongSumAggregatorFactory("_a2:sum", "a0"), + new DoubleSumAggregatorFactory("_a2:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a2:count"), notNull("a0") @@ -590,7 +591,7 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest .setContext(queryContext) .build() ), - ImmutableList.of(new Object[]{1L, 1L, 1L, 978480000L, 6L}) + ImmutableList.of(new Object[]{1L, 1L, 1.0, 978480000L, 6L}) ); }