SQL: Fix natural comparator selection for groupBy. (#14075)

* SQL: Fix natural comparator selection for groupBy.

DruidQuery.computeSorting had some unique logic for finding natural
comparators for SQL types. It should be using getStringComparatorForRelDataType
instead.

One good effect here is that the comparator for BOOLEAN is now
NUMERIC rather than LEXICOGRAPHIC. The test case illustrates this.

* Remove msqCompatible, for now.

* Fix test.
This commit is contained in:
Gian Merlino 2023-04-14 18:44:43 -07:00 committed by GitHub
parent eeed5ed7e2
commit a8eb3f2f57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 12 deletions

View File

@ -41,7 +41,6 @@ import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
@ -64,7 +63,6 @@ import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.operator.WindowOperatorQuery; import org.apache.druid.query.operator.WindowOperatorQuery;
import org.apache.druid.query.ordering.StringComparator; import org.apache.druid.query.ordering.StringComparator;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.query.timeboundary.TimeBoundaryQuery; import org.apache.druid.query.timeboundary.TimeBoundaryQuery;
import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.query.timeseries.TimeseriesQuery;
@ -209,7 +207,6 @@ public class DruidQuery
// Now the fun begins. // Now the fun begins.
final DimFilter filter; final DimFilter filter;
final Projection selectProjection; final Projection selectProjection;
final Projection unnestProjection;
final Grouping grouping; final Grouping grouping;
final Sorting sorting; final Sorting sorting;
final Windowing windowing; final Windowing windowing;
@ -438,6 +435,7 @@ public class DruidQuery
* @param rowSignature source row signature * @param rowSignature source row signature
* @param virtualColumnRegistry re-usable virtual column references * @param virtualColumnRegistry re-usable virtual column references
* @param typeFactory factory for SQL types * @param typeFactory factory for SQL types
*
* @return dimensions * @return dimensions
* *
* @throws CannotBuildQueryException if dimensions cannot be computed * @throws CannotBuildQueryException if dimensions cannot be computed
@ -625,14 +623,7 @@ public class DruidQuery
throw new ISE("Don't know what to do with direction[%s]", collation.getDirection()); throw new ISE("Don't know what to do with direction[%s]", collation.getDirection());
} }
final SqlTypeName sortExpressionType = sortExpression.getType().getSqlTypeName(); comparator = Calcites.getStringComparatorForRelDataType(sortExpression.getType());
if (SqlTypeName.NUMERIC_TYPES.contains(sortExpressionType)
|| SqlTypeName.TIMESTAMP == sortExpressionType
|| SqlTypeName.DATE == sortExpressionType) {
comparator = StringComparators.NUMERIC;
} else {
comparator = StringComparators.LEXICOGRAPHIC;
}
if (sortExpression.isA(SqlKind.INPUT_REF)) { if (sortExpression.isA(SqlKind.INPUT_REF)) {
final RexInputRef ref = (RexInputRef) sortExpression; final RexInputRef ref = (RexInputRef) sortExpression;

View File

@ -5222,7 +5222,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
.filters(new InDimFilter("m2", new HashSet<>(Arrays.asList(null, "A")))) .filters(new InDimFilter("m2", new HashSet<>(Arrays.asList(null, "A"))))
.virtualColumns(expressionVirtualColumn("v0", "notnull(\"m2\")", ColumnType.LONG)) .virtualColumns(expressionVirtualColumn("v0", "notnull(\"m2\")", ColumnType.LONG))
.dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)) .dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))
.metric(new InvertedTopNMetricSpec(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))) .metric(new InvertedTopNMetricSpec(new DimensionTopNMetricSpec(null, StringComparators.NUMERIC)))
.aggregators(new CountAggregatorFactory("a0")) .aggregators(new CountAggregatorFactory("a0"))
.threshold(1) .threshold(1)
.build() .build()

View File

@ -8764,6 +8764,61 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
); );
} }
/**
 * Groups by the BOOLEAN expression {@code dim1 = 'abc'}, orders by that expression and then by the count, and
 * checks that the planned groupBy query orders both columns with {@link StringComparators#NUMERIC}.
 */
@Test
public void testGroupByOrderByBoolean()
{
  // Not msqCompatible until https://github.com/apache/druid/pull/14046 is merged.
  final String sql = "SELECT dim1 = 'abc', COUNT(*) FROM druid.foo GROUP BY 1 ORDER BY 1, 2 LIMIT 2";

  // ORDER BY 1, 2 LIMIT 2: the boolean dimension first, then the count, both ascending.
  final DefaultLimitSpec expectedLimitSpec = new DefaultLimitSpec(
      ImmutableList.of(
          new OrderByColumnSpec("d0", Direction.ASCENDING, StringComparators.NUMERIC),
          new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
      ),
      2
  );

  // The boolean comparison is expected to be planned as a LONG-typed expression virtual column
  // that is then used as the single grouping dimension.
  final GroupByQuery expectedQuery =
      GroupByQuery.builder()
                  .setDataSource(CalciteTests.DATASOURCE1)
                  .setInterval(querySegmentSpec(Filtration.eternity()))
                  .setGranularity(Granularities.ALL)
                  .setVirtualColumns(expressionVirtualColumn("v0", "(\"dim1\" == 'abc')", ColumnType.LONG))
                  .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
                  .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
                  .setLimitSpec(expectedLimitSpec)
                  .setContext(QUERY_CONTEXT_DEFAULT)
                  .build();

  testQuery(
      sql,
      ImmutableList.of(expectedQuery),
      ImmutableList.of(
          new Object[]{false, 5L},
          new Object[]{true, 1L}
      )
  );
}
@Test @Test
public void testGroupByFloorTimeAndOneOtherDimensionWithOrderBy() public void testGroupByFloorTimeAndOneOtherDimensionWithOrderBy()
{ {