diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java index 857c8cb0d12..251a5d7925c 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java @@ -133,6 +133,8 @@ public class DruidSqlValidator extends BaseDruidSqlValidator throw Util.unexpected(windowOrId.getKind()); } + updateBoundsIfNeeded(targetWindow); + @Nullable SqlNode lowerBound = targetWindow.getLowerBound(); @Nullable @@ -144,17 +146,6 @@ public class DruidSqlValidator extends BaseDruidSqlValidator ); } - if (lowerBound != null && upperBound == null) { - if (lowerBound.getKind() == SqlKind.FOLLOWING || SqlWindow.isUnboundedFollowing(lowerBound)) { - upperBound = lowerBound; - lowerBound = SqlWindow.createCurrentRow(SqlParserPos.ZERO); - } else { - upperBound = SqlWindow.createCurrentRow(SqlParserPos.ZERO); - } - targetWindow.setLowerBound(lowerBound); - targetWindow.setUpperBound(upperBound); - } - boolean hasBounds = lowerBound != null || upperBound != null; if (call.getKind() == SqlKind.NTILE && hasBounds) { throw buildCalciteContextException( @@ -758,6 +749,28 @@ public class DruidSqlValidator extends BaseDruidSqlValidator || SqlWindow.isUnboundedPreceding(bound); } + /** + * Checks if any bound is null and updates with CURRENT ROW + */ + private void updateBoundsIfNeeded(SqlWindow window) + { + @Nullable + SqlNode lowerBound = window.getLowerBound(); + @Nullable + SqlNode upperBound = window.getUpperBound(); + + if (lowerBound != null && upperBound == null) { + if (lowerBound.getKind() == SqlKind.FOLLOWING || SqlWindow.isUnboundedFollowing(lowerBound)) { + upperBound = lowerBound; + lowerBound = SqlWindow.createCurrentRow(SqlParserPos.ZERO); + } else { + upperBound = SqlWindow.createCurrentRow(SqlParserPos.ZERO); + } + window.setLowerBound(lowerBound); + window.setUpperBound(upperBound); + } + } + @Override public void validateCall(SqlCall call, SqlValidatorScope scope) { @@ -812,6 +825,10 @@ public class DruidSqlValidator extends BaseDruidSqlValidator sqlNode ); } + if (sqlNode instanceof SqlWindow) { + SqlWindow window = (SqlWindow) sqlNode; + updateBoundsIfNeeded(window); + } } super.validateWindowClause(select); } diff --git a/sql/src/test/resources/calcite/tests/window/defaultBoundCurrentRow.sqlTest b/sql/src/test/resources/calcite/tests/window/defaultBoundCurrentRow.sqlTest index aa0a4a2a019..d5289523135 100644 --- a/sql/src/test/resources/calcite/tests/window/defaultBoundCurrentRow.sqlTest +++ b/sql/src/test/resources/calcite/tests/window/defaultBoundCurrentRow.sqlTest @@ -7,10 +7,11 @@ sql: | count(*) OVER (partition by dim2 ORDER BY dim1 ROWS 1 PRECEDING), count(*) OVER (partition by dim2 ORDER BY dim1 ROWS CURRENT ROW), count(*) OVER (partition by dim2 ORDER BY dim1 ROWS 1 FOLLOWING), - count(*) OVER (partition by dim2 ORDER BY dim1 ROWS UNBOUNDED FOLLOWING) + count(*) OVER W FROM numfoo WHERE dim2 IN ('a', 'abc') GROUP BY dim2, dim1 + WINDOW W AS (partition by dim2 ORDER BY dim1 ROWS UNBOUNDED FOLLOWING) expectedOperators: - {"type":"naiveSort","columns":[{"column":"_d1","direction":"ASC"},{"column":"_d0","direction":"ASC"}]}