Support to pass dynamic values to timestamp Extract function (#15586)

Fixes #15072

Before this modification , the third parameter (timezone) require to be a Literal, it will throw a error when this parameter is column Identifier.
This commit is contained in:
AlbericByte 2023-12-20 22:27:52 -08:00 committed by GitHub
parent 8a45efbf65
commit a2e65e6a89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 272 additions and 140 deletions

View File

@ -124,7 +124,6 @@ public interface NamedFunction
} }
} }
/** /**
* Helper method for implementors performing validation to check that the argument list is some expected size. * Helper method for implementors performing validation to check that the argument list is some expected size.
* *

View File

@ -19,6 +19,7 @@
package org.apache.druid.query.expression; package org.apache.druid.query.expression;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprEval;
@ -64,6 +65,87 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro
return FN_NAME; return FN_NAME;
} }
private ExprEval getExprEval(final DateTime dateTime, final Unit unit)
{
long epoch = dateTime.getMillis() / 1000;
switch (unit) {
case EPOCH:
return ExprEval.of(epoch);
case MICROSECOND:
return ExprEval.of(epoch / 1000);
case MILLISECOND:
return ExprEval.of(dateTime.millisOfSecond().get());
case SECOND:
return ExprEval.of(dateTime.secondOfMinute().get());
case MINUTE:
return ExprEval.of(dateTime.minuteOfHour().get());
case HOUR:
return ExprEval.of(dateTime.hourOfDay().get());
case DAY:
return ExprEval.of(dateTime.dayOfMonth().get());
case DOW:
return ExprEval.of(dateTime.dayOfWeek().get());
case ISODOW:
return ExprEval.of(dateTime.dayOfWeek().get());
case DOY:
return ExprEval.of(dateTime.dayOfYear().get());
case WEEK:
return ExprEval.of(dateTime.weekOfWeekyear().get());
case MONTH:
return ExprEval.of(dateTime.monthOfYear().get());
case QUARTER:
return ExprEval.of((dateTime.monthOfYear().get() - 1) / 3 + 1);
case YEAR:
return ExprEval.of(dateTime.year().get());
case ISOYEAR:
return ExprEval.of(dateTime.year().get());
case DECADE:
// The year field divided by 10, See https://www.postgresql.org/docs/10/functions-datetime.html
return ExprEval.of(dateTime.year().get() / 10);
case CENTURY:
return ExprEval.of(Math.ceil((double) dateTime.year().get() / 100));
case MILLENNIUM:
// Years in the 1900s are in the second millennium. The third millennium started January 1, 2001.
// See https://www.postgresql.org/docs/10/functions-datetime.html
return ExprEval.of(Math.ceil((double) dateTime.year().get() / 1000));
default:
throw TimestampExtractExprMacro.this.validationFailed("unhandled unit[%s]", unit);
}
}
private static ExpressionType getOutputExpressionType(final Unit unit)
{
switch (unit) {
case CENTURY:
case MILLENNIUM:
return ExpressionType.DOUBLE;
default:
return ExpressionType.LONG;
}
}
private static String stringifyExpr(final List<Expr> args)
{
if (args.size() > 2) {
return StringUtils.format(
"%s(%s, %s, %s)",
FN_NAME,
args.get(0).stringify(),
args.get(1).stringify(),
args.get(2).stringify()
);
}
return StringUtils.format("%s(%s, %s)", FN_NAME, args.get(0).stringify(), args.get(1).stringify());
}
private static ISOChronology computeChronology(final List<Expr> args, final Expr.ObjectBinding bindings)
{
String timeZoneVal = (String) args.get(2).eval(bindings).value();
return timeZoneVal != null
? ISOChronology.getInstance(DateTimes.inferTzFromString(timeZoneVal))
: ISOChronology.getInstanceUTC();
}
@Override @Override
public Expr apply(final List<Expr> args) public Expr apply(final List<Expr> args)
{ {
@ -73,121 +155,106 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro
throw validationFailed("unit arg must be literal"); throw validationFailed("unit arg must be literal");
} }
if (args.size() > 2) {
validationHelperCheckArgIsLiteral(args.get(2), "timezone");
}
final Expr arg = args.get(0);
final Unit unit = Unit.valueOf(StringUtils.toUpperCase((String) args.get(1).getLiteralValue())); final Unit unit = Unit.valueOf(StringUtils.toUpperCase((String) args.get(1).getLiteralValue()));
final DateTimeZone timeZone;
if (args.size() > 2) { if (args.size() > 2) {
timeZone = ExprUtils.toTimeZone(args.get(2)); if (args.get(2).isLiteral()) {
} else { DateTimeZone timeZone = ExprUtils.toTimeZone(args.get(2));
timeZone = DateTimeZone.UTC; ISOChronology chronology = ISOChronology.getInstance(timeZone);
return new TimestampExtractExpr(args, unit, chronology);
} else {
return new TimestampExtractDynamicExpr(args, unit);
}
} }
return new TimestampExtractExpr(args, unit, ISOChronology.getInstanceUTC());
}
final ISOChronology chronology = ISOChronology.getInstance(timeZone); public class TimestampExtractExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
private final ISOChronology chronology;
private final Unit unit;
class TimestampExtractExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr private TimestampExtractExpr(final List<Expr> args, final Unit unit, final ISOChronology chronology)
{ {
private TimestampExtractExpr(Expr arg) super(FN_NAME, args);
{ this.unit = unit;
super(FN_NAME, arg); this.chronology = chronology;
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
Object val = arg.eval(bindings).value();
if (val == null) {
// Return null if the argument if null.
return ExprEval.of(null);
}
final DateTime dateTime = new DateTime(val, chronology);
long epoch = dateTime.getMillis() / 1000;
switch (unit) {
case EPOCH:
return ExprEval.of(epoch);
case MICROSECOND:
return ExprEval.of(epoch / 1000);
case MILLISECOND:
return ExprEval.of(dateTime.millisOfSecond().get());
case SECOND:
return ExprEval.of(dateTime.secondOfMinute().get());
case MINUTE:
return ExprEval.of(dateTime.minuteOfHour().get());
case HOUR:
return ExprEval.of(dateTime.hourOfDay().get());
case DAY:
return ExprEval.of(dateTime.dayOfMonth().get());
case DOW:
return ExprEval.of(dateTime.dayOfWeek().get());
case ISODOW:
return ExprEval.of(dateTime.dayOfWeek().get());
case DOY:
return ExprEval.of(dateTime.dayOfYear().get());
case WEEK:
return ExprEval.of(dateTime.weekOfWeekyear().get());
case MONTH:
return ExprEval.of(dateTime.monthOfYear().get());
case QUARTER:
return ExprEval.of((dateTime.monthOfYear().get() - 1) / 3 + 1);
case YEAR:
return ExprEval.of(dateTime.year().get());
case ISOYEAR:
return ExprEval.of(dateTime.year().get());
case DECADE:
// The year field divided by 10, See https://www.postgresql.org/docs/10/functions-datetime.html
return ExprEval.of(dateTime.year().get() / 10);
case CENTURY:
return ExprEval.of(Math.ceil((double) dateTime.year().get() / 100));
case MILLENNIUM:
// Years in the 1900s are in the second millennium. The third millennium started January 1, 2001.
// See https://www.postgresql.org/docs/10/functions-datetime.html
return ExprEval.of(Math.ceil((double) dateTime.year().get() / 1000));
default:
throw TimestampExtractExprMacro.this.validationFailed("unhandled unit[%s]", unit);
}
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(apply(shuttle.visitAll(args)));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
switch (unit) {
case CENTURY:
case MILLENNIUM:
return ExpressionType.DOUBLE;
default:
return ExpressionType.LONG;
}
}
@Override
public String stringify()
{
if (args.size() > 2) {
return StringUtils.format(
"%s(%s, %s, %s)",
FN_NAME,
arg.stringify(),
args.get(1).stringify(),
args.get(2).stringify()
);
}
return StringUtils.format("%s(%s, %s)", FN_NAME, arg.stringify(), args.get(1).stringify());
}
} }
return new TimestampExtractExpr(arg); @Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
Object val = args.get(0).eval(bindings).value();
if (val == null) {
// Return null if the argument if null.
return ExprEval.of(null);
}
final DateTime dateTime = new DateTime(val, chronology);
return getExprEval(dateTime, unit);
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(apply(shuttle.visitAll(args)));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return getOutputExpressionType(unit);
}
@Override
public String stringify()
{
return stringifyExpr(args);
}
}
public class TimestampExtractDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
private final Unit unit;
private TimestampExtractDynamicExpr(final List<Expr> args, final Unit unit)
{
super(FN_NAME, args);
this.unit = unit;
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
Object val = args.get(0).eval(bindings).value();
if (val == null) {
// Return null if the argument if null.
return ExprEval.of(null);
}
final ISOChronology chronology = computeChronology(args, bindings);
final DateTime dateTime = new DateTime(val, chronology);
return getExprEval(dateTime, unit);
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(apply(shuttle.visitAll(args)));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return getOutputExpressionType(unit);
}
@Override
public String stringify()
{
return stringifyExpr(args);
}
} }
} }

View File

@ -20,10 +20,13 @@
package org.apache.druid.query.expression; package org.apache.druid.query.expression;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -100,4 +103,29 @@ public class TimestampExtractExprMacroTest
)); ));
Assert.assertEquals(3, expression.eval(InputBindings.nilBindings()).asInt()); Assert.assertEquals(3, expression.eval(InputBindings.nilBindings()).asInt());
} }
@Test
public void testApplyExtractDowWithTimeZoneShouldBeFriday()
{
Expr expression = target.apply(
ImmutableList.of(
ExprEval.of("2023-12-15").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.DOW.toString()).toExpr(),
ExprEval.of("UTC").toExpr()
));
Assert.assertEquals(5, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
public void testApplyExtractDowWithDynamicTimeZoneShouldBeFriday()
{
Expr expression = Parser.parse("timestamp_extract(time, 'DOW', timezone)", TestExprMacroTable.INSTANCE);
Expr.ObjectBinding bindings = InputBindings.forInputSuppliers(
ImmutableMap.of(
"time", InputBindings.inputSupplier(ExpressionType.STRING, () -> "2023-12-15"),
"timezone", InputBindings.inputSupplier(ExpressionType.STRING, () -> "UTC")
)
);
Assert.assertEquals(5, expression.eval(bindings).asInt());
}
} }

View File

@ -21,6 +21,7 @@ package org.apache.druid.sql.calcite.expression.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunction;
@ -65,6 +66,23 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
); );
} }
public static DruidExpression applyTimeExtract(
final DruidExpression timeExpression,
final TimestampExtractExprMacro.Unit unit,
final DruidExpression timeZoneExpression
)
{
return DruidExpression.ofFunctionCall(
timeExpression.getDruidType(),
"timestamp_extract",
ImmutableList.of(
timeExpression,
DruidExpression.ofStringLiteral(unit.name()),
timeZoneExpression
)
);
}
@Override @Override
public SqlFunction calciteOperator() public SqlFunction calciteOperator()
{ {
@ -89,13 +107,23 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
StringUtils.toUpperCase(RexLiteral.stringValue(call.getOperands().get(1))) StringUtils.toUpperCase(RexLiteral.stringValue(call.getOperands().get(1)))
); );
final DateTimeZone timeZone = OperatorConversions.getOperandWithDefault( if (call.getOperands().size() > 2 && call.getOperands().get(2) instanceof RexInputRef) {
call.getOperands(), final RexNode timeZoneArg = call.getOperands().get(2);
2, final DruidExpression timeZoneExpression = Expressions.toDruidExpression(
operand -> DateTimes.inferTzFromString(RexLiteral.stringValue(operand)), plannerContext,
plannerContext.getTimeZone() rowSignature,
); timeZoneArg
);
return applyTimeExtract(timeExpression, unit, timeZoneExpression);
} else {
final DateTimeZone timeZone = OperatorConversions.getOperandWithDefault(
call.getOperands(),
2,
operand -> DateTimes.inferTzFromString(RexLiteral.stringValue(operand)),
plannerContext.getTimeZone()
);
return applyTimeExtract(timeExpression, unit, timeZone); return applyTimeExtract(timeExpression, unit, timeZone);
}
} }
} }

View File

@ -99,34 +99,33 @@ public class ExpressionsTest extends CalciteTestBase
.add("newliney", ColumnType.STRING) .add("newliney", ColumnType.STRING)
.add("tstr", ColumnType.STRING) .add("tstr", ColumnType.STRING)
.add("dstr", ColumnType.STRING) .add("dstr", ColumnType.STRING)
.add("timezone", ColumnType.STRING)
.build(); .build();
private static final Map<String, Object> BINDINGS = ImmutableMap.<String, Object>builder() private static final Map<String, Object> BINDINGS = ImmutableMap.<String, Object>builder()
.put( .put("t", DateTimes.of("2000-02-03T04:05:06").getMillis())
"t", .put("a", 10)
DateTimes.of("2000-02-03T04:05:06").getMillis() .put("b", 25)
) .put("p", 3)
.put("a", 10) .put("x", 2.25)
.put("b", 25) .put("y", 3.0)
.put("p", 3) .put("z", -2.25)
.put("x", 2.25) .put("o", 0)
.put("y", 3.0) .put("nan", Double.NaN)
.put("z", -2.25) .put("inf", Double.POSITIVE_INFINITY)
.put("o", 0) .put("-inf", Double.NEGATIVE_INFINITY)
.put("nan", Double.NaN) .put("fnan", Float.NaN)
.put("inf", Double.POSITIVE_INFINITY) .put("finf", Float.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY) .put("-finf", Float.NEGATIVE_INFINITY)
.put("fnan", Float.NaN) .put("s", "foo")
.put("finf", Float.POSITIVE_INFINITY) .put("hexstr", "EF")
.put("-finf", Float.NEGATIVE_INFINITY) .put("intstr", "-100")
.put("s", "foo") .put("spacey", " hey there ")
.put("hexstr", "EF") .put("newliney", "beep\nboop")
.put("intstr", "-100") .put("tstr", "2000-02-03 04:05:06")
.put("spacey", " hey there ") .put("dstr", "2000-02-03")
.put("newliney", "beep\nboop") .put("timezone", "America/Los_Angeles")
.put("tstr", "2000-02-03 04:05:06") .build();
.put("dstr", "2000-02-03")
.build();
private ExpressionTestHelper testHelper; private ExpressionTestHelper testHelper;
@ -1844,6 +1843,17 @@ public class ExpressionsTest extends CalciteTestBase
makeExpression(ColumnType.LONG, "timestamp_extract(\"t\",'DAY','America/Los_Angeles')"), makeExpression(ColumnType.LONG, "timestamp_extract(\"t\",'DAY','America/Los_Angeles')"),
2L 2L
); );
testHelper.testExpressionString(
new TimeExtractOperatorConversion().calciteOperator(),
ImmutableList.of(
testHelper.makeInputRef("t"),
testHelper.makeLiteral("DAY"),
testHelper.makeInputRef("timezone")
),
makeExpression(ColumnType.LONG, "timestamp_extract(\"t\",'DAY',\"timezone\")"),
2L
);
} }
@Test @Test