Support projection after sorting in SQL (#5788)

* Add sort project

* add more test

* address comments
This commit is contained in:
Jihoon Son 2018-06-11 08:33:47 -07:00 committed by Gian Merlino
parent e43e5ebbcd
commit fe4d678aac
9 changed files with 582 additions and 98 deletions

View File

@ -36,6 +36,7 @@ import io.druid.sql.calcite.filtration.Filtration;
import io.druid.sql.calcite.table.RowSignature;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@ -112,7 +113,7 @@ public class Aggregation
public static Aggregation create(final PostAggregator postAggregator)
{
return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator);
return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
}
public static Aggregation create(

View File

@ -89,6 +89,7 @@ import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.TreeSet;
import java.util.stream.Collectors;
@ -105,9 +106,11 @@ public class DruidQuery
private final DimFilter filter;
private final SelectProjection selectProjection;
private final Grouping grouping;
private final SortProject sortProject;
private final DefaultLimitSpec limitSpec;
private final RowSignature outputRowSignature;
private final RelDataType outputRowType;
private final DefaultLimitSpec limitSpec;
private final Query query;
public DruidQuery(
@ -128,15 +131,22 @@ public class DruidQuery
this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder);
final RowSignature sortingInputRowSignature;
if (this.selectProjection != null) {
this.outputRowSignature = this.selectProjection.getOutputRowSignature();
sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
} else if (this.grouping != null) {
this.outputRowSignature = this.grouping.getOutputRowSignature();
sortingInputRowSignature = this.grouping.getOutputRowSignature();
} else {
this.outputRowSignature = sourceRowSignature;
sortingInputRowSignature = sourceRowSignature;
}
this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);
// outputRowSignature is used only for scan and select query, and thus sort and grouping must be null
this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();
this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
this.query = computeQuery();
}
@ -235,7 +245,7 @@ public class DruidQuery
)
{
final Aggregate aggregate = partialQuery.getAggregate();
final Project postProject = partialQuery.getPostProject();
final Project aggregateProject = partialQuery.getAggregateProject();
if (aggregate == null) {
return null;
@ -265,49 +275,27 @@ public class DruidQuery
plannerContext
);
if (postProject == null) {
if (aggregateProject == null) {
return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
} else {
final List<String> rowOrder = new ArrayList<>();
int outputNameCounter = 0;
for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
// Attempt to convert to PostAggregator.
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
plannerContext,
aggregateRowSignature,
postAggregatorRexNode
);
if (postAggregatorExpression == null) {
throw new CannotBuildQueryException(postProject, postAggregatorRexNode);
}
if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
// Direct column access, without any type cast as far as Druid's runtime is concerned.
// (There might be a SQL-level type cast that we don't care about)
rowOrder.add(postAggregatorExpression.getDirectColumn());
} else {
final String postAggregatorName = "p" + outputNameCounter++;
final PostAggregator postAggregator = new ExpressionPostAggregator(
postAggregatorName,
postAggregatorExpression.getExpression(),
null,
plannerContext.getExprMacroTable()
);
aggregations.add(Aggregation.create(postAggregator));
rowOrder.add(postAggregator.getName());
}
}
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
aggregateRowSignature,
aggregateProject,
0
);
projectRowOrderAndPostAggregations.postAggregations.forEach(
postAggregator -> aggregations.add(Aggregation.create(postAggregator))
);
// Remove literal dimensions that did not appear in the projection. This is useful for queries
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
// actually want to include a dimension 'dummy'.
final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
.isLiteral() && !postProjectBits.get(i)) {
.isLiteral() && !aggregateProjectBits.get(i)) {
dimensions.remove(i);
}
}
@ -316,11 +304,98 @@ public class DruidQuery
dimensions,
aggregations,
havingFilter,
RowSignature.from(rowOrder, postProject.getRowType())
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
);
}
}
@Nullable
private SortProject computeSortProject(
PartialDruidQuery partialQuery,
PlannerContext plannerContext,
RowSignature sortingInputRowSignature,
Grouping grouping
)
{
final Project sortProject = partialQuery.getSortProject();
if (sortProject == null) {
return null;
} else {
final List<PostAggregator> postAggregators = grouping.getPostAggregators();
final OptionalInt maybeMaxCounter = postAggregators
.stream()
.mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
.max();
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
sortingInputRowSignature,
sortProject,
maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
);
return new SortProject(
sortingInputRowSignature,
projectRowOrderAndPostAggregations.postAggregations,
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
);
}
}
private static class ProjectRowOrderAndPostAggregations
{
private final List<String> rowOrder;
private final List<PostAggregator> postAggregations;
ProjectRowOrderAndPostAggregations(List<String> rowOrder, List<PostAggregator> postAggregations)
{
this.rowOrder = rowOrder;
this.postAggregations = postAggregations;
}
}
private static ProjectRowOrderAndPostAggregations computePostAggregations(
PlannerContext plannerContext,
RowSignature inputRowSignature,
Project project,
int outputNameCounter
)
{
final List<String> rowOrder = new ArrayList<>();
final List<PostAggregator> aggregations = new ArrayList<>();
for (final RexNode postAggregatorRexNode : project.getChildExps()) {
// Attempt to convert to PostAggregator.
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
plannerContext,
inputRowSignature,
postAggregatorRexNode
);
if (postAggregatorExpression == null) {
throw new CannotBuildQueryException(project, postAggregatorRexNode);
}
if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
// Direct column access, without any type cast as far as Druid's runtime is concerned.
// (There might be a SQL-level type cast that we don't care about)
rowOrder.add(postAggregatorExpression.getDirectColumn());
} else {
final String postAggregatorName = "p" + outputNameCounter++;
final PostAggregator postAggregator = new ExpressionPostAggregator(
postAggregatorName,
postAggregatorExpression.getExpression(),
null,
plannerContext.getExprMacroTable()
);
aggregations.add(postAggregator);
rowOrder.add(postAggregator.getName());
}
}
return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
}
/**
* Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order.
*
@ -540,18 +615,20 @@ public class DruidQuery
{
final List<VirtualColumn> retVal = new ArrayList<>();
if (grouping != null) {
if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
} else {
if (grouping != null) {
if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
}
}
for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns());
}
}
for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns());
}
} else if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
}
return VirtualColumns.create(retVal);
@ -567,6 +644,11 @@ public class DruidQuery
return limitSpec;
}
public SortProject getSortProject()
{
return sortProject;
}
public RelDataType getOutputRowType()
{
return outputRowType;
@ -667,7 +749,6 @@ public class DruidQuery
if (limitSpec != null) {
// If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).
if (limitSpec.isLimited()) {
return null;
}
@ -797,6 +878,11 @@ public class DruidQuery
final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);
final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
if (sortProject != null) {
postAggregators.addAll(sortProject.getPostAggregators());
}
return new GroupByQuery(
dataSource,
filtration.getQuerySegmentSpec(),
@ -805,7 +891,7 @@ public class DruidQuery
Granularities.ALL,
grouping.getDimensionSpecs(),
grouping.getAggregatorFactories(),
grouping.getPostAggregators(),
postAggregators,
grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
limitSpec,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())

View File

@ -220,14 +220,18 @@ public class DruidQueryRel extends DruidRel<DruidQueryRel>
cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
}
if (partialQuery.getPostProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size();
if (partialQuery.getAggregateProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
}
if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
cost *= COST_LIMIT_MULTIPLIER;
}
if (partialQuery.getSortProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
}
if (partialQuery.getHavingFilter() != null) {
cost *= COST_HAVING_MULTIPLIER;
}

View File

@ -358,8 +358,12 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
}
if (leftPartialQuery.getPostProject() != null) {
newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
if (leftPartialQuery.getAggregateProject() != null) {
newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
}
if (leftPartialQuery.getSortProject() != null) {
newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
}
if (leftPartialQuery.getSort() != null) {

View File

@ -46,8 +46,9 @@ public class PartialDruidQuery
private final Sort selectSort;
private final Aggregate aggregate;
private final Filter havingFilter;
private final Project postProject;
private final Project aggregateProject;
private final Sort sort;
private final Project sortProject;
public enum Stage
{
@ -57,8 +58,9 @@ public class PartialDruidQuery
SELECT_SORT,
AGGREGATE,
HAVING_FILTER,
POST_PROJECT,
SORT
AGGREGATE_PROJECT,
SORT,
SORT_PROJECT
}
public PartialDruidQuery(
@ -67,9 +69,10 @@ public class PartialDruidQuery
final Project selectProject,
final Sort selectSort,
final Aggregate aggregate,
final Project postProject,
final Project aggregateProject,
final Filter havingFilter,
final Sort sort
final Sort sort,
final Project sortProject
)
{
this.scan = Preconditions.checkNotNull(scan, "scan");
@ -77,14 +80,15 @@ public class PartialDruidQuery
this.selectProject = selectProject;
this.selectSort = selectSort;
this.aggregate = aggregate;
this.postProject = postProject;
this.aggregateProject = aggregateProject;
this.havingFilter = havingFilter;
this.sort = sort;
this.sortProject = sortProject;
}
public static PartialDruidQuery create(final RelNode scanRel)
{
return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null);
return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null, null);
}
public RelNode getScan()
@ -117,9 +121,9 @@ public class PartialDruidQuery
return havingFilter;
}
public Project getPostProject()
public Project getAggregateProject()
{
return postProject;
return aggregateProject;
}
public Sort getSort()
@ -127,6 +131,11 @@ public class PartialDruidQuery
return sort;
}
public Project getSortProject()
{
return sortProject;
}
public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
{
validateStage(Stage.WHERE_FILTER);
@ -136,9 +145,10 @@ public class PartialDruidQuery
selectProject,
selectSort,
aggregate,
postProject,
aggregateProject,
havingFilter,
sort
sort,
sortProject
);
}
@ -151,9 +161,10 @@ public class PartialDruidQuery
newSelectProject,
selectSort,
aggregate,
postProject,
aggregateProject,
havingFilter,
sort
sort,
sortProject
);
}
@ -166,9 +177,10 @@ public class PartialDruidQuery
selectProject,
newSelectSort,
aggregate,
postProject,
aggregateProject,
havingFilter,
sort
sort,
sortProject
);
}
@ -181,9 +193,10 @@ public class PartialDruidQuery
selectProject,
selectSort,
newAggregate,
postProject,
aggregateProject,
havingFilter,
sort
sort,
sortProject
);
}
@ -196,24 +209,26 @@ public class PartialDruidQuery
selectProject,
selectSort,
aggregate,
postProject,
aggregateProject,
newHavingFilter,
sort
sort,
sortProject
);
}
public PartialDruidQuery withPostProject(final Project newPostProject)
public PartialDruidQuery withAggregateProject(final Project newAggregateProject)
{
validateStage(Stage.POST_PROJECT);
validateStage(Stage.AGGREGATE_PROJECT);
return new PartialDruidQuery(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
newPostProject,
newAggregateProject,
havingFilter,
sort
sort,
sortProject
);
}
@ -226,9 +241,26 @@ public class PartialDruidQuery
selectProject,
selectSort,
aggregate,
postProject,
aggregateProject,
havingFilter,
newSort
newSort,
sortProject
);
}
public PartialDruidQuery withSortProject(final Project newSortProject)
{
validateStage(Stage.SORT_PROJECT);
return new PartialDruidQuery(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
aggregateProject,
havingFilter,
sort,
newSortProject
);
}
@ -265,6 +297,9 @@ public class PartialDruidQuery
} else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
// Cannot do any aggregations after a select + sort.
return false;
} else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
// Cannot add sort project without a sort
return false;
} else {
// Looks good.
return true;
@ -277,12 +312,15 @@ public class PartialDruidQuery
*
* @return stage
*/
@SuppressWarnings("VariableNotUsedInsideIf")
public Stage stage()
{
if (sort != null) {
if (sortProject != null) {
return Stage.SORT_PROJECT;
} else if (sort != null) {
return Stage.SORT;
} else if (postProject != null) {
return Stage.POST_PROJECT;
} else if (aggregateProject != null) {
return Stage.AGGREGATE_PROJECT;
} else if (havingFilter != null) {
return Stage.HAVING_FILTER;
} else if (aggregate != null) {
@ -308,10 +346,12 @@ public class PartialDruidQuery
final Stage currentStage = stage();
switch (currentStage) {
case SORT_PROJECT:
return sortProject;
case SORT:
return sort;
case POST_PROJECT:
return postProject;
case AGGREGATE_PROJECT:
return aggregateProject;
case HAVING_FILTER:
return havingFilter;
case AGGREGATE:
@ -352,14 +392,25 @@ public class PartialDruidQuery
Objects.equals(selectSort, that.selectSort) &&
Objects.equals(aggregate, that.aggregate) &&
Objects.equals(havingFilter, that.havingFilter) &&
Objects.equals(postProject, that.postProject) &&
Objects.equals(sort, that.sort);
Objects.equals(aggregateProject, that.aggregateProject) &&
Objects.equals(sort, that.sort) &&
Objects.equals(sortProject, that.sortProject);
}
@Override
public int hashCode()
{
return Objects.hash(scan, whereFilter, selectProject, selectSort, aggregate, havingFilter, postProject, sort);
return Objects.hash(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
havingFilter,
aggregateProject,
sort,
sortProject
);
}
@Override
@ -372,8 +423,9 @@ public class PartialDruidQuery
", selectSort=" + selectSort +
", aggregate=" + aggregate +
", havingFilter=" + havingFilter +
", postProject=" + postProject +
", aggregateProject=" + aggregateProject +
", sort=" + sort +
", sortProject=" + sortProject +
'}';
}
}

View File

@ -0,0 +1,112 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.sql.calcite.rel;
import com.google.common.base.Preconditions;
import io.druid.java.util.common.ISE;
import io.druid.query.aggregation.PostAggregator;
import io.druid.sql.calcite.table.RowSignature;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
public class SortProject
{
private final RowSignature inputRowSignature;
private final List<PostAggregator> postAggregators;
private final RowSignature outputRowSignature;
SortProject(
RowSignature inputRowSignature,
List<PostAggregator> postAggregators,
RowSignature outputRowSignature
)
{
this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, "inputRowSignature");
this.postAggregators = Preconditions.checkNotNull(postAggregators, "postAggregators");
this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, "outputRowSignature");
// Verify no collisions.
final Set<String> seen = new HashSet<>();
inputRowSignature.getRowOrder().forEach(field -> {
if (!seen.add(field)) {
throw new ISE("Duplicate field name: %s", field);
}
});
for (PostAggregator postAggregator : postAggregators) {
if (postAggregator == null) {
throw new ISE("aggregation[%s] is not a postAggregator", postAggregator);
}
if (!seen.add(postAggregator.getName())) {
throw new ISE("Duplicate field name: %s", postAggregator.getName());
}
}
// Verify that items in the output signature exist.
outputRowSignature.getRowOrder().forEach(field -> {
if (!seen.contains(field)) {
throw new ISE("Missing field in rowOrder: %s", field);
}
});
}
public List<PostAggregator> getPostAggregators()
{
return postAggregators;
}
public RowSignature getOutputRowSignature()
{
return outputRowSignature;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SortProject sortProject = (SortProject) o;
return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
Objects.equals(postAggregators, sortProject.postAggregators) &&
Objects.equals(outputRowSignature, sortProject.outputRowSignature);
}
@Override
public int hashCode()
{
return Objects.hash(inputRowSignature, postAggregators, outputRowSignature);
}
@Override
public String toString()
{
return "SortProject{" +
"inputRowSignature=" + inputRowSignature +
", postAggregators=" + postAggregators +
", outputRowSignature=" + outputRowSignature +
'}';
}
}

View File

@ -68,8 +68,8 @@ public class DruidRules
),
new DruidQueryRule<>(
Project.class,
PartialDruidQuery.Stage.POST_PROJECT,
PartialDruidQuery::withPostProject
PartialDruidQuery.Stage.AGGREGATE_PROJECT,
PartialDruidQuery::withAggregateProject
),
new DruidQueryRule<>(
Filter.class,
@ -81,10 +81,16 @@ public class DruidRules
PartialDruidQuery.Stage.SORT,
PartialDruidQuery::withSort
),
new DruidQueryRule<>(
Project.class,
PartialDruidQuery.Stage.SORT_PROJECT,
PartialDruidQuery::withSortProject
),
DruidOuterQueryRule.AGGREGATE,
DruidOuterQueryRule.FILTER_AGGREGATE,
DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
DruidOuterQueryRule.PROJECT_AGGREGATE
DruidOuterQueryRule.PROJECT_AGGREGATE,
DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
);
}
@ -227,6 +233,32 @@ public class DruidRules
}
};
public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
operand(Project.class, operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, any())))),
"AGGREGATE_SORT_PROJECT"
)
{
@Override
public void onMatch(RelOptRuleCall call)
{
final Project sortProject = call.rel(0);
final Sort sort = call.rel(1);
final Aggregate aggregate = call.rel(2);
final DruidRel druidRel = call.rel(3);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withAggregate(aggregate)
.withSort(sort)
.withSortProject(sortProject)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public DruidOuterQueryRule(final RelOptRuleOperand op, final String description)
{
super(op, StringUtils.format("%s:%s", DruidOuterQueryRel.class.getSimpleName(), description));

View File

@ -24,6 +24,7 @@ import com.google.common.base.Predicates;
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.DruidRel;
import io.druid.sql.calcite.rel.DruidSemiJoin;
import io.druid.sql.calcite.rel.PartialDruidQuery;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
@ -115,15 +116,18 @@ public class DruidSemiJoinRule extends RelOptRule
return;
}
final Project rightPostProject = right.getPartialDruidQuery().getPostProject();
final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
final Project rightProject = rightQuery.getSortProject() != null ?
rightQuery.getSortProject() :
rightQuery.getAggregateProject();
int i = 0;
for (int joinRef : joinInfo.rightSet()) {
final int aggregateRef;
if (rightPostProject == null) {
if (rightProject == null) {
aggregateRef = joinRef;
} else {
final RexNode projectExp = rightPostProject.getChildExps().get(joinRef);
final RexNode projectExp = rightProject.getChildExps().get(joinRef);
if (projectExp.isA(SqlKind.INPUT_REF)) {
aggregateRef = ((RexInputRef) projectExp).getIndex();
} else {

View File

@ -70,6 +70,7 @@ import io.druid.query.groupby.GroupByQuery;
import io.druid.query.groupby.having.DimFilterHavingSpec;
import io.druid.query.groupby.orderby.DefaultLimitSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
import io.druid.query.lookup.RegisteredLookupExtractionFn;
import io.druid.query.ordering.StringComparator;
import io.druid.query.ordering.StringComparators;
@ -119,6 +120,7 @@ import org.junit.rules.TemporaryFolder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -6529,6 +6531,193 @@ public class CalciteQueryTest extends CalciteTestBase
);
}
@Test
public void testProjectAfterSort() throws Exception
{
testQuery(
"select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
DIMS(
new DefaultDimensionSpec("dim1", "d0"),
new DefaultDimensionSpec("dim2", "d1")
)
)
.setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{""},
new Object[]{"1"},
new Object[]{"10.1"},
new Object[]{"2"},
new Object[]{"abc"},
new Object[]{"def"}
)
);
}
@Test
public void testProjectAfterSort2() throws Exception
{
testQuery(
"select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
DIMS(
new DefaultDimensionSpec("dim1", "d0"),
new DefaultDimensionSpec("dim2", "d1")
)
)
.setAggregatorSpecs(
AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2"))
)
.setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")")))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1.0, "", "a", 1.0},
new Object[]{4.0, "1", "a", 4.0},
new Object[]{2.0, "10.1", "", 2.0},
new Object[]{3.0, "2", "", 3.0},
new Object[]{6.0, "abc", "", 6.0},
new Object[]{5.0, "def", "abc", 5.0}
)
);
}
@Test
public void testProjectAfterSort3() throws Exception
{
testQuery(
"select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
DIMS(
new DefaultDimensionSpec("dim1", "d0")
)
)
.setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{""},
new Object[]{"1"},
new Object[]{"10.1"},
new Object[]{"2"},
new Object[]{"abc"},
new Object[]{"def"}
)
);
}
@Test
public void testSortProjectAfterNestedGroupBy() throws Exception
{
testQuery(
"SELECT "
+ " cnt "
+ "FROM ("
+ " SELECT "
+ " __time, "
+ " dim1, "
+ " COUNT(m2) AS cnt "
+ " FROM ("
+ " SELECT "
+ " __time, "
+ " m2, "
+ " dim1 "
+ " FROM druid.foo "
+ " GROUP BY __time, m2, dim1 "
+ " ) "
+ " GROUP BY __time, dim1 "
+ " ORDER BY cnt"
+ ")",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(DIMS(
new DefaultDimensionSpec("__time", "d0", ValueType.LONG),
new DefaultDimensionSpec("dim1", "d1"),
new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(DIMS(
new DefaultDimensionSpec("d0", "_d0", ValueType.LONG),
new DefaultDimensionSpec("d1", "_d1", ValueType.STRING)
))
.setAggregatorSpecs(AGGS(
new CountAggregatorFactory("a0")
))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L},
new Object[]{1L},
new Object[]{1L},
new Object[]{1L},
new Object[]{1L},
new Object[]{1L}
)
);
}
private void testQuery(
final String sql,
final List<Query> expectedQueries,