From 8252d72e2a87422f3bdb293e16b12d3a6c4aea90 Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Tue, 12 Mar 2024 17:14:31 +0100 Subject: [PATCH] Pull up literals in InputAccessor (#16033) * Pull up literals in InputAccessor * pull up literals in `InputAccessor` * remove the need to pass `constants` of `Window` operator Fixes #15353 * update test * enable relax_nulls --- .../hll/sql/HllSketchSqlAggregatorTest.java | 22 +++++ .../druid/sql/calcite/rel/DruidQuery.java | 5 +- .../druid/sql/calcite/rel/InputAccessor.java | 84 ++++++++++++------- .../druid/sql/calcite/rel/Windowing.java | 5 +- 4 files changed, 82 insertions(+), 34 deletions(-) diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 22204a5f9a4..538ca817180 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -1268,6 +1268,7 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest ); } + /** * This is a test in a similar vein to {@link #testEstimateStringAndDoubleAreDifferent()} except here we are * ensuring that float values and doubles values are considered equivalent. The expected initial inputs were @@ -1318,6 +1319,27 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest ); } + @Test + public void testDsHllOnTopOfNested() + { + // this query was not planable: https://github.com/apache/druid/issues/15353 + testBuilder() + .sql( + "SELECT d1,dim2,APPROX_COUNT_DISTINCT_DS_HLL(dim2, 18) as val" + + " FROM (select d1,dim1,dim2 from druid.foo group by d1,dim1,dim2 order by dim1 limit 3) t " + + " group by 1,2" + ) + .expectedResults( + ResultMatchMode.RELAX_NULLS, + ImmutableList.of( + new Object[] {null, "a", 1L}, + new Object[] {"1.0", "a", 1L}, + new Object[] {"1.7", null, 0L} + ) + ) + .run(); + } + private ExpressionVirtualColumn makeSketchEstimateExpression(String outputName, String field) { return new ExpressionVirtualColumn( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index 6e0bab21277..f3ea896b842 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -591,10 +591,9 @@ public class DruidQuery virtualColumnRegistry, rexBuilder, InputAccessor.buildFor( - rexBuilder, - rowSignature, + aggregate, partialQuery.getSelectProject(), - null), + rowSignature), aggregations, aggName, aggCall, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java index 12c81d88756..aeadab0fa9a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java @@ -20,12 +20,17 @@ package org.apache.druid.sql.calcite.rel; import com.google.common.collect.ImmutableList; +import org.apache.calcite.plan.RelOptPredicateList; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; import java.util.List; @@ -38,43 +43,67 @@ import java.util.stream.Collectors; */ public class InputAccessor { - private final Project project; - private final ImmutableList constants; - private final RexBuilder rexBuilder; + private final RelNode relNode; + @Nullable + private final Project flattenedProject; private final RowSignature inputRowSignature; + @Nullable + private final ImmutableList constants; + private final RelNode inputRelNode; + private final RelDataType inputRelRowType; + private final RelOptPredicateList predicates; private final int inputFieldCount; + private final RelDataType inputDruidRowType; public static InputAccessor buildFor( - RexBuilder rexBuilder, - RowSignature inputRowSignature, - @Nullable Project project, - @Nullable ImmutableList constants) + RelNode relNode, + @Nullable Project flattenedProject, + RowSignature rowSignature) { - return new InputAccessor(rexBuilder, inputRowSignature, project, constants); + return new InputAccessor( + relNode, + flattenedProject, + rowSignature + ); } private InputAccessor( - RexBuilder rexBuilder, - RowSignature inputRowSignature, - Project project, - ImmutableList constants) + RelNode relNode, + Project flattenedProject, + RowSignature rowSignature) { - this.rexBuilder = rexBuilder; - this.inputRowSignature = inputRowSignature; - this.project = project; - this.constants = constants; - this.inputFieldCount = project != null ? project.getRowType().getFieldCount() : inputRowSignature.size(); + this.relNode = relNode; + this.constants = getConstants(relNode); + this.inputRelNode = relNode.getInput(0).stripped(); + this.flattenedProject = flattenedProject; + this.inputRowSignature = rowSignature; + this.inputRelRowType = inputRelNode.getRowType(); + this.predicates = relNode.getCluster().getMetadataQuery().getPulledUpPredicates(inputRelNode); + this.inputFieldCount = inputRelRowType.getFieldCount(); + this.inputDruidRowType = RowSignatures.toRelDataType(inputRowSignature, getRexBuilder().getTypeFactory()); + } + + private ImmutableList getConstants(RelNode relNode) + { + if (relNode instanceof Window) { + return ((Window) relNode).constants; + } + return null; } public RexNode getField(int argIndex) { - if (argIndex < inputFieldCount) { - return Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - inputRowSignature, - project, - argIndex); + RexInputRef inputRef = RexInputRef.of(argIndex, inputRelRowType); + RexNode constant = predicates.constantMap.get(inputRef); + if (constant != null) { + return constant; + } + if (flattenedProject != null) { + return flattenedProject.getProjects().get(argIndex); + } else { + return RexInputRef.of(argIndex, inputDruidRowType); + } } else { return constants.get(argIndex - inputFieldCount); } @@ -90,18 +119,17 @@ public class InputAccessor public @Nullable Project getProject() { - return project; + return flattenedProject; } - public RexBuilder getRexBuilder() { - return rexBuilder; + return relNode.getCluster().getRexBuilder(); } - public RowSignature getInputRowSignature() { return inputRowSignature; } + } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java index 60f0f1d539d..c96b3bdd39f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java @@ -180,10 +180,9 @@ public class Windowing virtualColumnRegistry, rexBuilder, InputAccessor.buildFor( - rexBuilder, - sourceRowSignature, + window, partialQuery.getSelectProject(), - window.constants), + sourceRowSignature), Collections.emptyList(), aggName, aggregateCall,