Support to pass dynamic values to timestamp Extract function (#15586)

Fixes #15072

Before this modification , the third parameter (timezone) require to be a Literal, it will throw a error when this parameter is column Identifier.
This commit is contained in:
AlbericByte 2023-12-20 22:27:52 -08:00 committed by GitHub
parent 8a45efbf65
commit a2e65e6a89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 272 additions and 140 deletions

View File

@ -124,7 +124,6 @@ public interface NamedFunction
} }
} }
/** /**
* Helper method for implementors performing validation to check that the argument list is some expected size. * Helper method for implementors performing validation to check that the argument list is some expected size.
* *

View File

@ -19,6 +19,7 @@
package org.apache.druid.query.expression; package org.apache.druid.query.expression;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprEval;
@ -64,50 +65,9 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro
return FN_NAME; return FN_NAME;
} }
@Override private ExprEval getExprEval(final DateTime dateTime, final Unit unit)
public Expr apply(final List<Expr> args)
{ {
validationHelperCheckArgumentRange(args, 2, 3);
if (!args.get(1).isLiteral() || args.get(1).getLiteralValue() == null) {
throw validationFailed("unit arg must be literal");
}
if (args.size() > 2) {
validationHelperCheckArgIsLiteral(args.get(2), "timezone");
}
final Expr arg = args.get(0);
final Unit unit = Unit.valueOf(StringUtils.toUpperCase((String) args.get(1).getLiteralValue()));
final DateTimeZone timeZone;
if (args.size() > 2) {
timeZone = ExprUtils.toTimeZone(args.get(2));
} else {
timeZone = DateTimeZone.UTC;
}
final ISOChronology chronology = ISOChronology.getInstance(timeZone);
class TimestampExtractExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
private TimestampExtractExpr(Expr arg)
{
super(FN_NAME, arg);
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
Object val = arg.eval(bindings).value();
if (val == null) {
// Return null if the argument if null.
return ExprEval.of(null);
}
final DateTime dateTime = new DateTime(val, chronology);
long epoch = dateTime.getMillis() / 1000; long epoch = dateTime.getMillis() / 1000;
switch (unit) { switch (unit) {
case EPOCH: case EPOCH:
return ExprEval.of(epoch); return ExprEval.of(epoch);
@ -153,15 +113,7 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro
} }
} }
@Override private static ExpressionType getOutputExpressionType(final Unit unit)
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(apply(shuttle.visitAll(args)));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{ {
switch (unit) { switch (unit) {
case CENTURY: case CENTURY:
@ -172,22 +124,137 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro
} }
} }
@Override private static String stringifyExpr(final List<Expr> args)
public String stringify()
{ {
if (args.size() > 2) { if (args.size() > 2) {
return StringUtils.format( return StringUtils.format(
"%s(%s, %s, %s)", "%s(%s, %s, %s)",
FN_NAME, FN_NAME,
arg.stringify(), args.get(0).stringify(),
args.get(1).stringify(), args.get(1).stringify(),
args.get(2).stringify() args.get(2).stringify()
); );
} }
return StringUtils.format("%s(%s, %s)", FN_NAME, arg.stringify(), args.get(1).stringify()); return StringUtils.format("%s(%s, %s)", FN_NAME, args.get(0).stringify(), args.get(1).stringify());
}
private static ISOChronology computeChronology(final List<Expr> args, final Expr.ObjectBinding bindings)
{
String timeZoneVal = (String) args.get(2).eval(bindings).value();
return timeZoneVal != null
? ISOChronology.getInstance(DateTimes.inferTzFromString(timeZoneVal))
: ISOChronology.getInstanceUTC();
}
@Override
public Expr apply(final List<Expr> args)
{
validationHelperCheckArgumentRange(args, 2, 3);
if (!args.get(1).isLiteral() || args.get(1).getLiteralValue() == null) {
throw validationFailed("unit arg must be literal");
}
final Unit unit = Unit.valueOf(StringUtils.toUpperCase((String) args.get(1).getLiteralValue()));
if (args.size() > 2) {
if (args.get(2).isLiteral()) {
DateTimeZone timeZone = ExprUtils.toTimeZone(args.get(2));
ISOChronology chronology = ISOChronology.getInstance(timeZone);
return new TimestampExtractExpr(args, unit, chronology);
} else {
return new TimestampExtractDynamicExpr(args, unit);
}
}
return new TimestampExtractExpr(args, unit, ISOChronology.getInstanceUTC());
}
public class TimestampExtractExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
private final ISOChronology chronology;
private final Unit unit;
private TimestampExtractExpr(final List<Expr> args, final Unit unit, final ISOChronology chronology)
{
super(FN_NAME, args);
this.unit = unit;
this.chronology = chronology;
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
Object val = args.get(0).eval(bindings).value();
if (val == null) {
// Return null if the argument if null.
return ExprEval.of(null);
}
final DateTime dateTime = new DateTime(val, chronology);
return getExprEval(dateTime, unit);
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(apply(shuttle.visitAll(args)));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return getOutputExpressionType(unit);
}
@Override
public String stringify()
{
return stringifyExpr(args);
} }
} }
return new TimestampExtractExpr(arg); public class TimestampExtractDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
private final Unit unit;
private TimestampExtractDynamicExpr(final List<Expr> args, final Unit unit)
{
super(FN_NAME, args);
this.unit = unit;
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
Object val = args.get(0).eval(bindings).value();
if (val == null) {
// Return null if the argument if null.
return ExprEval.of(null);
}
final ISOChronology chronology = computeChronology(args, bindings);
final DateTime dateTime = new DateTime(val, chronology);
return getExprEval(dateTime, unit);
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(apply(shuttle.visitAll(args)));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return getOutputExpressionType(unit);
}
@Override
public String stringify()
{
return stringifyExpr(args);
}
} }
} }

View File

@ -20,10 +20,13 @@
package org.apache.druid.query.expression; package org.apache.druid.query.expression;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -100,4 +103,29 @@ public class TimestampExtractExprMacroTest
)); ));
Assert.assertEquals(3, expression.eval(InputBindings.nilBindings()).asInt()); Assert.assertEquals(3, expression.eval(InputBindings.nilBindings()).asInt());
} }
@Test
public void testApplyExtractDowWithTimeZoneShouldBeFriday()
{
Expr expression = target.apply(
ImmutableList.of(
ExprEval.of("2023-12-15").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.DOW.toString()).toExpr(),
ExprEval.of("UTC").toExpr()
));
Assert.assertEquals(5, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
public void testApplyExtractDowWithDynamicTimeZoneShouldBeFriday()
{
Expr expression = Parser.parse("timestamp_extract(time, 'DOW', timezone)", TestExprMacroTable.INSTANCE);
Expr.ObjectBinding bindings = InputBindings.forInputSuppliers(
ImmutableMap.of(
"time", InputBindings.inputSupplier(ExpressionType.STRING, () -> "2023-12-15"),
"timezone", InputBindings.inputSupplier(ExpressionType.STRING, () -> "UTC")
)
);
Assert.assertEquals(5, expression.eval(bindings).asInt());
}
} }

View File

@ -21,6 +21,7 @@ package org.apache.druid.sql.calcite.expression.builtin;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunction;
@ -65,6 +66,23 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
); );
} }
public static DruidExpression applyTimeExtract(
final DruidExpression timeExpression,
final TimestampExtractExprMacro.Unit unit,
final DruidExpression timeZoneExpression
)
{
return DruidExpression.ofFunctionCall(
timeExpression.getDruidType(),
"timestamp_extract",
ImmutableList.of(
timeExpression,
DruidExpression.ofStringLiteral(unit.name()),
timeZoneExpression
)
);
}
@Override @Override
public SqlFunction calciteOperator() public SqlFunction calciteOperator()
{ {
@ -89,6 +107,15 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
StringUtils.toUpperCase(RexLiteral.stringValue(call.getOperands().get(1))) StringUtils.toUpperCase(RexLiteral.stringValue(call.getOperands().get(1)))
); );
if (call.getOperands().size() > 2 && call.getOperands().get(2) instanceof RexInputRef) {
final RexNode timeZoneArg = call.getOperands().get(2);
final DruidExpression timeZoneExpression = Expressions.toDruidExpression(
plannerContext,
rowSignature,
timeZoneArg
);
return applyTimeExtract(timeExpression, unit, timeZoneExpression);
} else {
final DateTimeZone timeZone = OperatorConversions.getOperandWithDefault( final DateTimeZone timeZone = OperatorConversions.getOperandWithDefault(
call.getOperands(), call.getOperands(),
2, 2,
@ -98,4 +125,5 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion
return applyTimeExtract(timeExpression, unit, timeZone); return applyTimeExtract(timeExpression, unit, timeZone);
} }
}
} }

View File

@ -99,13 +99,11 @@ public class ExpressionsTest extends CalciteTestBase
.add("newliney", ColumnType.STRING) .add("newliney", ColumnType.STRING)
.add("tstr", ColumnType.STRING) .add("tstr", ColumnType.STRING)
.add("dstr", ColumnType.STRING) .add("dstr", ColumnType.STRING)
.add("timezone", ColumnType.STRING)
.build(); .build();
private static final Map<String, Object> BINDINGS = ImmutableMap.<String, Object>builder() private static final Map<String, Object> BINDINGS = ImmutableMap.<String, Object>builder()
.put( .put("t", DateTimes.of("2000-02-03T04:05:06").getMillis())
"t",
DateTimes.of("2000-02-03T04:05:06").getMillis()
)
.put("a", 10) .put("a", 10)
.put("b", 25) .put("b", 25)
.put("p", 3) .put("p", 3)
@ -126,6 +124,7 @@ public class ExpressionsTest extends CalciteTestBase
.put("newliney", "beep\nboop") .put("newliney", "beep\nboop")
.put("tstr", "2000-02-03 04:05:06") .put("tstr", "2000-02-03 04:05:06")
.put("dstr", "2000-02-03") .put("dstr", "2000-02-03")
.put("timezone", "America/Los_Angeles")
.build(); .build();
private ExpressionTestHelper testHelper; private ExpressionTestHelper testHelper;
@ -1844,6 +1843,17 @@ public class ExpressionsTest extends CalciteTestBase
makeExpression(ColumnType.LONG, "timestamp_extract(\"t\",'DAY','America/Los_Angeles')"), makeExpression(ColumnType.LONG, "timestamp_extract(\"t\",'DAY','America/Los_Angeles')"),
2L 2L
); );
testHelper.testExpressionString(
new TimeExtractOperatorConversion().calciteOperator(),
ImmutableList.of(
testHelper.makeInputRef("t"),
testHelper.makeLiteral("DAY"),
testHelper.makeInputRef("timezone")
),
makeExpression(ColumnType.LONG, "timestamp_extract(\"t\",'DAY',\"timezone\")"),
2L
);
} }
@Test @Test