diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 000b8939162..4c8cd93d8ba 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -26,6 +26,8 @@ import org.apache.druid.java.util.common.StringUtils; import org.joda.time.DateTime; import org.joda.time.format.DateTimeFormat; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.List; /** @@ -499,7 +501,7 @@ interface Function } } - class Round extends SingleParamMath + class Round implements Function { @Override public String name() @@ -508,9 +510,42 @@ interface Function } @Override - protected ExprEval eval(double param) + public ExprEval apply(List args, Expr.ObjectBinding bindings) { - return ExprEval.of(Math.round(param)); + if (args.size() != 1 && args.size() != 2) { + throw new IAE("Function[%s] needs 1 or 2 arguments", name()); + } + + ExprEval value1 = args.get(0).eval(bindings); + if (value1.type() != ExprType.LONG && value1.type() != ExprType.DOUBLE) { + throw new IAE("The first argument to the function[%s] should be integer or double type but get the %s type", name(), value1.type()); + } + + if (args.size() == 1) { + return eval(value1); + } else { + ExprEval value2 = args.get(1).eval(bindings); + if (value2.type() != ExprType.LONG) { + throw new IAE("The second argument to the function[%s] should be integer type but get the %s type", name(), value2.type()); + } + return eval(value1, value2.asInt()); + } + } + + private ExprEval eval(ExprEval param) + { + return eval(param, 0); + } + + private ExprEval eval(ExprEval param, int scale) + { + if (param.type() == ExprType.LONG) { + return ExprEval.of(BigDecimal.valueOf(param.asLong()).setScale(scale, RoundingMode.HALF_UP).longValue()); + } else if (param.type() == ExprType.DOUBLE) { + return ExprEval.of(BigDecimal.valueOf(param.asDouble()).setScale(scale, RoundingMode.HALF_UP).doubleValue()); + } else { + return ExprEval.of(null); + } } } diff --git a/docs/content/misc/math-expr.md b/docs/content/misc/math-expr.md index ce3389f1b57..b8b84097ef9 100644 --- a/docs/content/misc/math-expr.md +++ b/docs/content/misc/math-expr.md @@ -133,7 +133,7 @@ See javadoc of java.lang.Math for detailed explanation for each function. |pow|pow(x, y) would return the value of the x raised to the power of y| |remainder|remainder(x, y) would return the remainder operation on two arguments as prescribed by the IEEE 754 standard| |rint|rint(x) would return value that is closest in value to x and is equal to a mathematical integer| -|round|round(x) would return the closest long value to x, with ties rounding up| +|round|round(x, y) would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points.| |scalb|scalb(d, sf) would return d * 2^sf rounded as if performed by a single correctly rounded floating-point multiply to a member of the double value set| |signum|signum(x) would return the signum function of the argument x| |sin|sin(x) would return the trigonometric sine of an angle x| diff --git a/docs/content/querying/sql.md b/docs/content/querying/sql.md index 5e7ef137d7f..d596c142b0b 100644 --- a/docs/content/querying/sql.md +++ b/docs/content/querying/sql.md @@ -149,6 +149,7 @@ Numeric functions will return 64 bit integers or 64 bit floats, depending on the |`SQRT(expr)`|Square root.| |`TRUNCATE(expr[, digits])`|Truncate expr to a specific number of decimal digits. If digits is negative, then this truncates that many places to the left of the decimal point. Digits defaults to zero if not specified.| |`TRUNC(expr[, digits])`|Synonym for `TRUNCATE`.| +|`ROUND(expr[, digits])`|`ROUND(x, y)` would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points.| |`x + y`|Addition.| |`x - y`|Subtraction.| |`x * y`|Multiplication.| diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java new file mode 100644 index 00000000000..558a1b9cda0 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RoundOperatorConversion.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.expression.builtin; + +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.table.RowSignature; + +public class RoundOperatorConversion implements SqlOperatorConversion +{ + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder("ROUND") + .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER) + .requiredOperands(1) + .returnTypeInference(ReturnTypes.ARG0) + .functionCategory(SqlFunctionCategory.NUMERIC) + .build(); + + @Override + public SqlFunction calciteOperator() + { + return SQL_FUNCTION; + } + + @Override + public DruidExpression toDruidExpression(final PlannerContext plannerContext, final RowSignature rowSignature, final RexNode rexNode) + { + return OperatorConversions.convertCall(plannerContext, rowSignature, rexNode, inputExpressions -> { + return DruidExpression.fromFunctionCall( + "round", + inputExpressions + ); + }); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index c6759b5f021..82682f56549 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -65,6 +65,7 @@ import org.apache.druid.sql.calcite.expression.builtin.ReinterpretOperatorConver import org.apache.druid.sql.calcite.expression.builtin.RepeatOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.ReverseOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.RightOperatorConversion; +import org.apache.druid.sql.calcite.expression.builtin.RoundOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.StringFormatOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.StrposOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.SubstringOperatorConversion; @@ -159,6 +160,7 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new BinaryOperatorConversion(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, "<=")) .add(new BinaryOperatorConversion(SqlStdOperatorTable.AND, "&&")) .add(new BinaryOperatorConversion(SqlStdOperatorTable.OR, "||")) + .add(new RoundOperatorConversion()) // time operators .add(new CeilOperatorConversion()) .add(new DateTruncOperatorConversion()) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java index fa890a10379..ee6f616dd26 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java @@ -48,6 +48,7 @@ import org.apache.druid.sql.calcite.expression.builtin.RegexpExtractOperatorConv import org.apache.druid.sql.calcite.expression.builtin.RepeatOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.ReverseOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.RightOperatorConversion; +import org.apache.druid.sql.calcite.expression.builtin.RoundOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.StringFormatOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.StrposOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.TimeExtractOperatorConversion; @@ -463,6 +464,82 @@ public class ExpressionsTest extends CalciteTestBase ); } + @Test + public void testRound() + { + final SqlFunction roundFunction = new RoundOperatorConversion().calciteOperator(); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("a")), + DruidExpression.fromExpression("round(\"a\")"), + 10L + ); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("b")), + DruidExpression.fromExpression("round(\"b\")"), + 25L + ); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("b"), integerLiteral(-1)), + DruidExpression.fromExpression("round(\"b\",-1)"), + 30L + ); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("x")), + DruidExpression.fromExpression("round(\"x\")"), + 2.0 + ); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("x"), integerLiteral(1)), + DruidExpression.fromExpression("round(\"x\",1)"), + 2.3 + ); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("y")), + DruidExpression.fromExpression("round(\"y\")"), + 3.0 + ); + + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("z")), + DruidExpression.fromExpression("round(\"z\")"), + -2.0 + ); + } + + @Test + public void testRoundWithInvalidArgument() + { + final SqlFunction roundFunction = new RoundOperatorConversion().calciteOperator(); + + expectedException.expect(IAE.class); + expectedException.expectMessage("The first argument to the function[round] should be integer or double type but get the STRING type"); + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("s")), + DruidExpression.fromExpression("round(\"s\")"), + "IAE Exception" + ); + } + + @Test + public void testRoundWithInvalidSecondArgument() + { + final SqlFunction roundFunction = new RoundOperatorConversion().calciteOperator(); + + expectedException.expect(IAE.class); + expectedException.expectMessage("The second argument to the function[round] should be integer type but get the STRING type"); + testExpression( + rexBuilder.makeCall(roundFunction, inputRef("x"), rexBuilder.makeLiteral("foo")), + DruidExpression.fromExpression("round(\"x\",'foo')"), + "IAE Exception" + ); + } + @Test public void testDateTrunc() {