Change type of AVG aggregates to double (#15089)

The SQL standard is not very restrictive regarding this:

If AVG is specified and DT is exact numeric, then the declared type of the result is an
implementation-defined exact numeric type with precision not less than the precision of DT and
scale not less than the scale of DT.

So using the same type is also OK (i.e. the behavior without this patch);
however, the AVG of 0 and 1 is 0 right now because the integer type is retained.

Postgres, MySQL, Oracle and Drill seem to increase precision; MSSQL returns 0:
http://sqlfiddle.com/#!9/6f7248/1

I think we should also increase precision, as it's already calculated more precisely.
This commit is contained in:
Zoltan Haindrich 2023-10-07 14:31:09 +02:00 committed by GitHub
parent 57ab8e13dc
commit 7b869fd37a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 56 additions and 58 deletions

View File

@ -369,7 +369,7 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
1L
1.0
}
);
@ -429,11 +429,11 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
.setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
notNull("a0")

View File

@ -278,7 +278,7 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
1L
1.0
}
);
@ -334,11 +334,11 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest
.setAggregatorSpecs(
NullHandling.replaceWithDefault()
? Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: Arrays.asList(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
notNull("a0")

View File

@ -30,6 +30,7 @@ import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
private static final String STDDEV_NAME = "STDDEV";
private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(VARIANCE_NAME);
buildSqlVarianceAggFunction(VARIANCE_NAME);
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
buildSqlVarianceAggFunction(SqlKind.VAR_POP.name());
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name());
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(STDDEV_NAME);
buildSqlVarianceAggFunction(STDDEV_NAME);
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name());
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name());
@Nullable
@Override
@ -160,14 +161,15 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
}
/**
* Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
* but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
* Creates a {@link SqlAggFunction}
*
* It accepts variance aggregator objects in addition to numeric inputs.
*/
private static SqlAggFunction buildSqlAvgAggFunction(String name)
private static SqlAggFunction buildSqlVarianceAggFunction(String name)
{
return OperatorConversions
.aggregatorBuilder(name)
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
.returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE))
.operandTypeChecker(
OperandTypes.or(
OperandTypes.NUMERIC,

View File

@ -171,8 +171,8 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(true),
holder2.getVariance(true).doubleValue(),
holder3.getVariance(true).longValue()
holder2.getVariance(true),
holder3.getVariance(true)
}
);
testQuery(
@ -219,7 +219,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
new Object[] {
holder1.getVariance(false),
holder2.getVariance(false).doubleValue(),
holder3.getVariance(false).longValue(),
holder3.getVariance(false),
}
);
testQuery(
@ -266,7 +266,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
new Object[] {
Math.sqrt(holder1.getVariance(true)),
Math.sqrt(holder2.getVariance(true)),
(long) Math.sqrt(holder3.getVariance(true)),
Math.sqrt(holder3.getVariance(true)),
}
);
@ -321,7 +321,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
new Object[]{
Math.sqrt(holder1.getVariance(false)),
Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
Math.sqrt(holder3.getVariance(false)),
}
);
@ -374,7 +374,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
new Object[]{
Math.sqrt(holder1.getVariance(false)),
Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
Math.sqrt(holder3.getVariance(false)),
}
);
@ -543,7 +543,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
: new Object[]{null, null, null, null, null, null, null, null}
)
);
@ -623,7 +623,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
: new Object[]{"a", null, null, null, null, null, null, null, null}
)
);
@ -688,9 +688,9 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}

View File

@ -124,13 +124,7 @@ public class DruidTypeSystem implements RelDataTypeSystem
final RelDataType argumentType
)
{
// Widen all averages to 64-bits regardless of the size of the inputs.
if (SqlTypeName.INT_TYPES.contains(argumentType.getSqlTypeName())) {
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.BIGINT, argumentType.isNullable());
} else {
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable());
}
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable());
}
@Override

View File

@ -29,9 +29,10 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
@ -127,7 +128,7 @@ public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
useDefault
? new CountAggregatorFactory("_a0:count")
: new FilteredAggregatorFactory(
@ -158,15 +159,15 @@ public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setAggregatorSpecs(new DoubleAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"India", 2L},
new Object[]{"USA", 1L},
new Object[]{"canada", 3L}
new Object[]{"India", 2.0},
new Object[]{"USA", 1.0},
new Object[]{"canada", 3.0}
)
);
}

View File

@ -221,7 +221,7 @@ public class CalciteParameterQueryTest extends BaseCalciteQueryTest
+ "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
ImmutableList.of(),
ImmutableList.of(
new Object[]{8L, 1249L, 156L, -5L, 1111L}
new Object[]{8L, 1249L, 156.125, -5L, 1111L}
),
ImmutableList.of(
new SqlParameter(SqlType.VARCHAR, "druid"),

View File

@ -374,7 +374,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
+ "WHERE TABLE_SCHEMA = 'druid' AND TABLE_NAME = 'foo'",
ImmutableList.of(),
ImmutableList.of(
new Object[]{8L, 1249L, 156L, -5L, 1111L}
new Object[]{8L, 1249L, 156.125, -5L, 1111L}
)
);
}
@ -4942,7 +4942,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a1"),
notNull("dim1")
),
new LongSumAggregatorFactory("a2:sum", "cnt"),
new DoubleSumAggregatorFactory("a2:sum", "cnt"),
new CountAggregatorFactory("a2:count"),
new LongSumAggregatorFactory("a3", "cnt"),
new LongMinAggregatorFactory("a4", "cnt"),
@ -4964,7 +4964,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a2"),
notNull("dim1")
),
new LongSumAggregatorFactory("a3:sum", "cnt"),
new DoubleSumAggregatorFactory("a3:sum", "cnt"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a3:count"),
notNull("cnt")
@ -5014,10 +5014,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
),
NullHandling.replaceWithDefault() ?
ImmutableList.of(
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
new Object[]{6L, 6L, 5L, 1.0, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
) :
ImmutableList.of(
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
new Object[]{6L, 6L, 6L, 1.0, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
)
);
}
@ -7429,11 +7429,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs(
useDefault
? aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
: aggregators(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a0:count"),
notNull("a0")
@ -7455,7 +7455,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{1L})
ImmutableList.of(new Object[]{1.0})
);
}
@ -9641,7 +9641,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new LongSumAggregatorFactory("a6", "l1"),
new LongMaxAggregatorFactory("a7", "l1"),
new LongMinAggregatorFactory("a8", "l1"),
new LongSumAggregatorFactory("a9:sum", "l1"),
new DoubleSumAggregatorFactory("a9:sum", "l1"),
useDefault
? new CountAggregatorFactory("a9:count")
: new FilteredAggregatorFactory(
@ -9690,7 +9690,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
0L,
Long.MIN_VALUE,
Long.MAX_VALUE,
0L,
Double.NaN,
Double.NaN
}
: new Object[]{0L, 0L, 0L, null, null, null, null, null, null, null, null}
@ -9936,7 +9936,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
equality("dim1", "nonexistent", ColumnType.STRING)
),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a9:sum", "l1"),
new DoubleSumAggregatorFactory("a9:sum", "l1"),
equality("dim1", "nonexistent", ColumnType.STRING)
),
useDefault
@ -10005,7 +10005,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
0L,
Long.MIN_VALUE,
Long.MAX_VALUE,
0L,
Double.NaN,
Double.NaN
}
: new Object[]{"a", 0L, 0L, 0L, null, null, null, null, null, null, null, null}
@ -13147,7 +13147,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a0"),
notNull("v0")
),
new LongSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
new DoubleSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
new CountAggregatorFactory("a1:count")
);
virtualColumns = ImmutableList.of(
@ -13160,7 +13160,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new CountAggregatorFactory("a0"),
notNull("v0")
),
new LongSumAggregatorFactory("a1:sum", "v1"),
new DoubleSumAggregatorFactory("a1:sum", "v1"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a1:count"),
notNull("v1")
@ -13204,7 +13204,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.build()
),
ImmutableList.of(
new Object[]{"ab", 1L, 325323L}
new Object[]{"ab", 1L, 325323.0}
)
);
}

View File

@ -34,6 +34,7 @@ import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongMinAggregatorFactory;
@ -558,14 +559,14 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest
aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new DoubleSumAggregatorFactory("_a2:sum", "a0"),
new CountAggregatorFactory("_a2:count"),
new LongMaxAggregatorFactory("_a3", "d0"),
new CountAggregatorFactory("_a4")
) : aggregators(
new LongMaxAggregatorFactory("_a0", "a0"),
new LongMinAggregatorFactory("_a1", "a0"),
new LongSumAggregatorFactory("_a2:sum", "a0"),
new DoubleSumAggregatorFactory("_a2:sum", "a0"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a2:count"),
notNull("a0")
@ -590,7 +591,7 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest
.setContext(queryContext)
.build()
),
ImmutableList.of(new Object[]{1L, 1L, 1L, 978480000L, 6L})
ImmutableList.of(new Object[]{1L, 1L, 1.0, 978480000L, 6L})
);
}