Round works correctly on system metadata columns (#15554)

This commit is contained in:
Soumyava 2023-12-13 17:23:14 -08:00 committed by GitHub
parent 81fe855b6f
commit 3e15522d6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 34 deletions

View File

@ -1436,11 +1436,12 @@ public interface Function extends NamedFunction
private static final BigDecimal MAX_FINITE_VALUE = BigDecimal.valueOf(Double.MAX_VALUE); private static final BigDecimal MAX_FINITE_VALUE = BigDecimal.valueOf(Double.MAX_VALUE);
private static final BigDecimal MIN_FINITE_VALUE = BigDecimal.valueOf(-1 * Double.MAX_VALUE); private static final BigDecimal MIN_FINITE_VALUE = BigDecimal.valueOf(-1 * Double.MAX_VALUE);
//CHECKSTYLE.ON: Regexp //CHECKSTYLE.ON: Regexp
public static final String NAME = "round";
@Override @Override
public String name() public String name()
{ {
return "round"; return NAME;
} }
@Override @Override

View File

@ -19,41 +19,21 @@
package org.apache.druid.sql.calcite.expression.builtin; package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ReturnTypes; import org.apache.druid.math.expr.Function;
import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.expression.SqlOperatorConversion;
import org.apache.druid.sql.calcite.planner.PlannerContext;
public class RoundOperatorConversion implements SqlOperatorConversion public class RoundOperatorConversion extends DirectOperatorConversion
{ {
private static final SqlFunction SQL_FUNCTION = OperatorConversions public RoundOperatorConversion()
.operatorBuilder("ROUND") {
.operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER) super(SqlStdOperatorTable.ROUND, Function.Round.NAME);
.requiredOperandCount(1) }
.returnTypeInference(ReturnTypes.ARG0)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
@Override @Override
public SqlFunction calciteOperator() public SqlFunction calciteOperator()
{ {
return SQL_FUNCTION; return SqlStdOperatorTable.ROUND;
}
@Override
public DruidExpression toDruidExpression(final PlannerContext plannerContext, final RowSignature rowSignature, final RexNode rexNode)
{
return OperatorConversions.convertDirectCall(
plannerContext,
rowSignature,
rexNode,
"round"
);
} }
} }

View File

@ -41,7 +41,8 @@ public class CalciteSysQueryTest extends BaseCalciteQueryTest
.sql("select datasource, sum(duration) from sys.tasks group by datasource") .sql("select datasource, sum(duration) from sys.tasks group by datasource")
.expectedResults(ImmutableList.of( .expectedResults(ImmutableList.of(
new Object[]{"foo", 11L}, new Object[]{"foo", 11L},
new Object[]{"foo2", 22L})) new Object[]{"foo2", 22L}
))
.expectedLogicalPlan("LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])\n" .expectedLogicalPlan("LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])\n"
+ " LogicalProject(exprs=[[$3, $8]])\n" + " LogicalProject(exprs=[[$3, $8]])\n"
+ " LogicalTableScan(table=[[sys, tasks]])\n") + " LogicalTableScan(table=[[sys, tasks]])\n")
@ -59,8 +60,43 @@ public class CalciteSysQueryTest extends BaseCalciteQueryTest
.sql("select datasource, sum(duration) over () from sys.tasks group by datasource") .sql("select datasource, sum(duration) over () from sys.tasks group by datasource")
.expectedResults(ImmutableList.of( .expectedResults(ImmutableList.of(
new Object[]{"foo", 11L}, new Object[]{"foo", 11L},
new Object[]{"foo2", 22L})) new Object[]{"foo2", 22L}
))
// please add expectedLogicalPlan if this test starts passing! // please add expectedLogicalPlan if this test starts passing!
.run(); .run();
} }
@Test
public void testRoundOnSysTableColumn()
{
msqIncompatible();
testBuilder()
.sql("select round(duration, 1) from sys.tasks ")
.expectedResults(ImmutableList.of(
new Object[]{10L},
new Object[]{1L},
new Object[]{20L},
new Object[]{2L}
))
.expectedLogicalPlan("LogicalProject(exprs=[[ROUND($8, 1)]])\n"
+ " LogicalTableScan(table=[[sys, tasks]])\n")
.run();
}
@Test
public void testRoundOnAvgOnSysTableColumn()
{
msqIncompatible();
testBuilder()
.sql("select round(avg(duration), 1) from sys.tasks ")
.expectedResults(ImmutableList.of(
new Object[]{8.3D}))
.expectedLogicalPlan("LogicalProject(exprs=[[ROUND($0, 1)]])\n"
+ " LogicalAggregate(group=[{}], agg#0=[AVG($0)])\n"
+ " LogicalProject(exprs=[[$8]])\n"
+ " LogicalTableScan(table=[[sys, tasks]])\n")
.run();
}
} }