From 3e15522d6b3200540841f456d918921811fff328 Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:23:14 -0800 Subject: [PATCH] Round works correctly on system metadata columns (#15554) --- .../org/apache/druid/math/expr/Function.java | 3 +- .../builtin/RoundOperatorConversion.java | 38 ++++------------ .../sql/calcite/CalciteSysQueryTest.java | 44 +++++++++++++++++-- 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 3f90f7c3999..01c317fe098 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -1436,11 +1436,12 @@ public interface Function extends NamedFunction private static final BigDecimal MAX_FINITE_VALUE = BigDecimal.valueOf(Double.MAX_VALUE); private static final BigDecimal MIN_FINITE_VALUE = BigDecimal.valueOf(-1 * Double.MAX_VALUE); //CHECKSTYLE.ON: Regexp + public static final String NAME = "round"; @Override public String name() { - return "round"; + return NAME; } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java index a14e3a0f0d9..e8031e55aee 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java @@ -19,41 +19,21 @@ package org.apache.druid.sql.calcite.expression.builtin; -import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlTypeFamily; -import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.OperatorConversions; -import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; -import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.druid.math.expr.Function; +import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; -public class RoundOperatorConversion implements SqlOperatorConversion +public class RoundOperatorConversion extends DirectOperatorConversion { - private static final SqlFunction SQL_FUNCTION = OperatorConversions - .operatorBuilder("ROUND") - .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER) - .requiredOperandCount(1) - .returnTypeInference(ReturnTypes.ARG0) - .functionCategory(SqlFunctionCategory.NUMERIC) - .build(); + public RoundOperatorConversion() + { + super(SqlStdOperatorTable.ROUND, Function.Round.NAME); + } @Override public SqlFunction calciteOperator() { - return SQL_FUNCTION; - } - - @Override - public DruidExpression toDruidExpression(final PlannerContext plannerContext, final RowSignature rowSignature, final RexNode rexNode) - { - return OperatorConversions.convertDirectCall( - plannerContext, - rowSignature, - rexNode, - "round" - ); + return SqlStdOperatorTable.ROUND; } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSysQueryTest.java index 5b0a5dd82e3..7d3737677bc 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSysQueryTest.java @@ -41,10 +41,11 @@ public class CalciteSysQueryTest extends BaseCalciteQueryTest .sql("select datasource, sum(duration) from sys.tasks group by datasource") .expectedResults(ImmutableList.of( new Object[]{"foo", 11L}, - new Object[]{"foo2", 22L})) + new Object[]{"foo2", 22L} + )) .expectedLogicalPlan("LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])\n" - + " LogicalProject(exprs=[[$3, $8]])\n" - + " LogicalTableScan(table=[[sys, tasks]])\n") + + " LogicalProject(exprs=[[$3, $8]])\n" + + " LogicalTableScan(table=[[sys, tasks]])\n") .run(); } @@ -59,8 +60,43 @@ public class CalciteSysQueryTest extends BaseCalciteQueryTest .sql("select datasource, sum(duration) over () from sys.tasks group by datasource") .expectedResults(ImmutableList.of( new Object[]{"foo", 11L}, - new Object[]{"foo2", 22L})) + new Object[]{"foo2", 22L} + )) // please add expectedLogicalPlan if this test starts passing! .run(); } + + @Test + public void testRoundOnSysTableColumn() + { + msqIncompatible(); + + testBuilder() + .sql("select round(duration, 1) from sys.tasks ") + .expectedResults(ImmutableList.of( + new Object[]{10L}, + new Object[]{1L}, + new Object[]{20L}, + new Object[]{2L} + )) + .expectedLogicalPlan("LogicalProject(exprs=[[ROUND($8, 1)]])\n" + + " LogicalTableScan(table=[[sys, tasks]])\n") + .run(); + } + + @Test + public void testRoundOnAvgOnSysTableColumn() + { + msqIncompatible(); + + testBuilder() + .sql("select round(avg(duration), 1) from sys.tasks ") + .expectedResults(ImmutableList.of( + new Object[]{8.3D})) + .expectedLogicalPlan("LogicalProject(exprs=[[ROUND($0, 1)]])\n" + + " LogicalAggregate(group=[{}], agg#0=[AVG($0)])\n" + + " LogicalProject(exprs=[[$8]])\n" + + " LogicalTableScan(table=[[sys, tasks]])\n") + .run(); + } }