mirror of https://github.com/apache/druid.git
Change type of AVG aggregates to double (#15089)
The sql standard is not very restrictive regarding this: If AVG is specified and DT is exact numeric, then the declared type of the result is an implemen- tation-defined exact numeric type with precision not less than the precision of DT and scale not less than the scale of DT. so; using the same type is also ok (without patch); however the avg of 0 and 1 is 0 right now because of the retention of the integer typ Postgres,MySql and Oracle and Drill seem to increase precision ; mssql returns 0 http://sqlfiddle.com/#!9/6f7248/1 I think we should also increase precision as its already calculated more precisely
This commit is contained in:
parent
57ab8e13dc
commit
7b869fd37a
|
@ -369,7 +369,7 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
1L
|
||||
1.0
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -429,11 +429,11 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
.setAggregatorSpecs(
|
||||
NullHandling.replaceWithDefault()
|
||||
? Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
)
|
||||
: Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a0:count"),
|
||||
notNull("a0")
|
||||
|
|
|
@ -278,7 +278,7 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
1L
|
||||
1.0
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -334,11 +334,11 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
.setAggregatorSpecs(
|
||||
NullHandling.replaceWithDefault()
|
||||
? Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
)
|
||||
: Arrays.asList(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a0:count"),
|
||||
notNull("a0")
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.calcite.sql.SqlFunctionCategory;
|
|||
import org.apache.calcite.sql.SqlKind;
|
||||
import org.apache.calcite.sql.type.OperandTypes;
|
||||
import org.apache.calcite.sql.type.ReturnTypes;
|
||||
import org.apache.calcite.sql.type.SqlTypeName;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
|
@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
private static final String STDDEV_NAME = "STDDEV";
|
||||
|
||||
private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(VARIANCE_NAME);
|
||||
buildSqlVarianceAggFunction(VARIANCE_NAME);
|
||||
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
|
||||
buildSqlVarianceAggFunction(SqlKind.VAR_POP.name());
|
||||
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
|
||||
buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name());
|
||||
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(STDDEV_NAME);
|
||||
buildSqlVarianceAggFunction(STDDEV_NAME);
|
||||
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
|
||||
buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name());
|
||||
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
|
||||
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
|
||||
buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name());
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
|
@ -160,14 +161,15 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
|
|||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
|
||||
* but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
|
||||
* Creates a {@link SqlAggFunction}
|
||||
*
|
||||
* It accepts variance aggregator objects in addition to numeric inputs.
|
||||
*/
|
||||
private static SqlAggFunction buildSqlAvgAggFunction(String name)
|
||||
private static SqlAggFunction buildSqlVarianceAggFunction(String name)
|
||||
{
|
||||
return OperatorConversions
|
||||
.aggregatorBuilder(name)
|
||||
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
|
||||
.returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE))
|
||||
.operandTypeChecker(
|
||||
OperandTypes.or(
|
||||
OperandTypes.NUMERIC,
|
||||
|
|
|
@ -171,8 +171,8 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
holder1.getVariance(true),
|
||||
holder2.getVariance(true).doubleValue(),
|
||||
holder3.getVariance(true).longValue()
|
||||
holder2.getVariance(true),
|
||||
holder3.getVariance(true)
|
||||
}
|
||||
);
|
||||
testQuery(
|
||||
|
@ -219,7 +219,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
new Object[] {
|
||||
holder1.getVariance(false),
|
||||
holder2.getVariance(false).doubleValue(),
|
||||
holder3.getVariance(false).longValue(),
|
||||
holder3.getVariance(false),
|
||||
}
|
||||
);
|
||||
testQuery(
|
||||
|
@ -266,7 +266,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
new Object[] {
|
||||
Math.sqrt(holder1.getVariance(true)),
|
||||
Math.sqrt(holder2.getVariance(true)),
|
||||
(long) Math.sqrt(holder3.getVariance(true)),
|
||||
Math.sqrt(holder3.getVariance(true)),
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -321,7 +321,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
new Object[]{
|
||||
Math.sqrt(holder1.getVariance(false)),
|
||||
Math.sqrt(holder2.getVariance(false)),
|
||||
(long) Math.sqrt(holder3.getVariance(false)),
|
||||
Math.sqrt(holder3.getVariance(false)),
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -374,7 +374,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
new Object[]{
|
||||
Math.sqrt(holder1.getVariance(false)),
|
||||
Math.sqrt(holder2.getVariance(false)),
|
||||
(long) Math.sqrt(holder3.getVariance(false)),
|
||||
Math.sqrt(holder3.getVariance(false)),
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -543,7 +543,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
),
|
||||
ImmutableList.of(
|
||||
NullHandling.replaceWithDefault()
|
||||
? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
|
||||
? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
|
||||
: new Object[]{null, null, null, null, null, null, null, null}
|
||||
)
|
||||
);
|
||||
|
@ -623,7 +623,7 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
),
|
||||
ImmutableList.of(
|
||||
NullHandling.replaceWithDefault()
|
||||
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
|
||||
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
|
||||
: new Object[]{"a", null, null, null, null, null, null, null, null}
|
||||
)
|
||||
);
|
||||
|
@ -688,9 +688,9 @@ public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
|
|||
Assert.assertEquals(expectedResult.length, result.length);
|
||||
for (int j = 0; j < expectedResult.length; j++) {
|
||||
if (expectedResult[j] instanceof Float) {
|
||||
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
|
||||
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5);
|
||||
} else if (expectedResult[j] instanceof Double) {
|
||||
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
|
||||
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5);
|
||||
} else {
|
||||
Assert.assertEquals(expectedResult[j], result[j]);
|
||||
}
|
||||
|
|
|
@ -124,14 +124,8 @@ public class DruidTypeSystem implements RelDataTypeSystem
|
|||
final RelDataType argumentType
|
||||
)
|
||||
{
|
||||
// Widen all averages to 64-bits regardless of the size of the inputs.
|
||||
|
||||
if (SqlTypeName.INT_TYPES.contains(argumentType.getSqlTypeName())) {
|
||||
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.BIGINT, argumentType.isNullable());
|
||||
} else {
|
||||
return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RelDataType deriveCovarType(
|
||||
|
|
|
@ -29,9 +29,10 @@ import org.apache.druid.java.util.common.granularity.Granularity;
|
|||
import org.apache.druid.query.QueryDataSource;
|
||||
import org.apache.druid.query.TableDataSource;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
|
||||
|
@ -127,7 +128,7 @@ public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest
|
|||
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
|
||||
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
|
||||
.setAggregatorSpecs(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
useDefault
|
||||
? new CountAggregatorFactory("_a0:count")
|
||||
: new FilteredAggregatorFactory(
|
||||
|
@ -158,15 +159,15 @@ public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
|
||||
.setDimensions(new DefaultDimensionSpec("country", "d0"))
|
||||
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
|
||||
.setAggregatorSpecs(new DoubleAnyAggregatorFactory("a0", "j0._a0"))
|
||||
.setGranularity(new AllGranularity())
|
||||
.setContext(queryContext)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{"India", 2L},
|
||||
new Object[]{"USA", 1L},
|
||||
new Object[]{"canada", 3L}
|
||||
new Object[]{"India", 2.0},
|
||||
new Object[]{"USA", 1.0},
|
||||
new Object[]{"canada", 3.0}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -221,7 +221,7 @@ public class CalciteParameterQueryTest extends BaseCalciteQueryTest
|
|||
+ "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
|
||||
ImmutableList.of(),
|
||||
ImmutableList.of(
|
||||
new Object[]{8L, 1249L, 156L, -5L, 1111L}
|
||||
new Object[]{8L, 1249L, 156.125, -5L, 1111L}
|
||||
),
|
||||
ImmutableList.of(
|
||||
new SqlParameter(SqlType.VARCHAR, "druid"),
|
||||
|
|
|
@ -374,7 +374,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
+ "WHERE TABLE_SCHEMA = 'druid' AND TABLE_NAME = 'foo'",
|
||||
ImmutableList.of(),
|
||||
ImmutableList.of(
|
||||
new Object[]{8L, 1249L, 156L, -5L, 1111L}
|
||||
new Object[]{8L, 1249L, 156.125, -5L, 1111L}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -4942,7 +4942,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new CountAggregatorFactory("a1"),
|
||||
notNull("dim1")
|
||||
),
|
||||
new LongSumAggregatorFactory("a2:sum", "cnt"),
|
||||
new DoubleSumAggregatorFactory("a2:sum", "cnt"),
|
||||
new CountAggregatorFactory("a2:count"),
|
||||
new LongSumAggregatorFactory("a3", "cnt"),
|
||||
new LongMinAggregatorFactory("a4", "cnt"),
|
||||
|
@ -4964,7 +4964,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new CountAggregatorFactory("a2"),
|
||||
notNull("dim1")
|
||||
),
|
||||
new LongSumAggregatorFactory("a3:sum", "cnt"),
|
||||
new DoubleSumAggregatorFactory("a3:sum", "cnt"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a3:count"),
|
||||
notNull("cnt")
|
||||
|
@ -5014,10 +5014,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
),
|
||||
NullHandling.replaceWithDefault() ?
|
||||
ImmutableList.of(
|
||||
new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
|
||||
new Object[]{6L, 6L, 5L, 1.0, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)}
|
||||
) :
|
||||
ImmutableList.of(
|
||||
new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
|
||||
new Object[]{6L, 6L, 6L, 1.0, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
@ -7429,11 +7429,11 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setAggregatorSpecs(
|
||||
useDefault
|
||||
? aggregators(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new CountAggregatorFactory("_a0:count")
|
||||
)
|
||||
: aggregators(
|
||||
new LongSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a0:count"),
|
||||
notNull("a0")
|
||||
|
@ -7455,7 +7455,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(new Object[]{1L})
|
||||
ImmutableList.of(new Object[]{1.0})
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -9641,7 +9641,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new LongSumAggregatorFactory("a6", "l1"),
|
||||
new LongMaxAggregatorFactory("a7", "l1"),
|
||||
new LongMinAggregatorFactory("a8", "l1"),
|
||||
new LongSumAggregatorFactory("a9:sum", "l1"),
|
||||
new DoubleSumAggregatorFactory("a9:sum", "l1"),
|
||||
useDefault
|
||||
? new CountAggregatorFactory("a9:count")
|
||||
: new FilteredAggregatorFactory(
|
||||
|
@ -9690,7 +9690,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
0L,
|
||||
Long.MIN_VALUE,
|
||||
Long.MAX_VALUE,
|
||||
0L,
|
||||
Double.NaN,
|
||||
Double.NaN
|
||||
}
|
||||
: new Object[]{0L, 0L, 0L, null, null, null, null, null, null, null, null}
|
||||
|
@ -9936,7 +9936,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
equality("dim1", "nonexistent", ColumnType.STRING)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new LongSumAggregatorFactory("a9:sum", "l1"),
|
||||
new DoubleSumAggregatorFactory("a9:sum", "l1"),
|
||||
equality("dim1", "nonexistent", ColumnType.STRING)
|
||||
),
|
||||
useDefault
|
||||
|
@ -10005,7 +10005,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
0L,
|
||||
Long.MIN_VALUE,
|
||||
Long.MAX_VALUE,
|
||||
0L,
|
||||
Double.NaN,
|
||||
Double.NaN
|
||||
}
|
||||
: new Object[]{"a", 0L, 0L, 0L, null, null, null, null, null, null, null, null}
|
||||
|
@ -13147,7 +13147,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new CountAggregatorFactory("a0"),
|
||||
notNull("v0")
|
||||
),
|
||||
new LongSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
|
||||
new DoubleSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE),
|
||||
new CountAggregatorFactory("a1:count")
|
||||
);
|
||||
virtualColumns = ImmutableList.of(
|
||||
|
@ -13160,7 +13160,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new CountAggregatorFactory("a0"),
|
||||
notNull("v0")
|
||||
),
|
||||
new LongSumAggregatorFactory("a1:sum", "v1"),
|
||||
new DoubleSumAggregatorFactory("a1:sum", "v1"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("a1:count"),
|
||||
notNull("v1")
|
||||
|
@ -13204,7 +13204,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{"ab", 1L, 325323L}
|
||||
new Object[]{"ab", 1L, 325323.0}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ import org.apache.druid.query.QueryDataSource;
|
|||
import org.apache.druid.query.ResourceLimitExceededException;
|
||||
import org.apache.druid.query.TableDataSource;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.LongMinAggregatorFactory;
|
||||
|
@ -558,14 +559,14 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest
|
|||
aggregators(
|
||||
new LongMaxAggregatorFactory("_a0", "a0"),
|
||||
new LongMinAggregatorFactory("_a1", "a0"),
|
||||
new LongSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new CountAggregatorFactory("_a2:count"),
|
||||
new LongMaxAggregatorFactory("_a3", "d0"),
|
||||
new CountAggregatorFactory("_a4")
|
||||
) : aggregators(
|
||||
new LongMaxAggregatorFactory("_a0", "a0"),
|
||||
new LongMinAggregatorFactory("_a1", "a0"),
|
||||
new LongSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new DoubleSumAggregatorFactory("_a2:sum", "a0"),
|
||||
new FilteredAggregatorFactory(
|
||||
new CountAggregatorFactory("_a2:count"),
|
||||
notNull("a0")
|
||||
|
@ -590,7 +591,7 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest
|
|||
.setContext(queryContext)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(new Object[]{1L, 1L, 1L, 978480000L, 6L})
|
||||
ImmutableList.of(new Object[]{1L, 1L, 1.0, 978480000L, 6L})
|
||||
);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue