Support projection after sorting in SQL (#5788)

* Add sort project

* add more test

* address comments
This commit is contained in:
Jihoon Son 2018-06-11 08:33:47 -07:00 committed by Gian Merlino
parent e43e5ebbcd
commit fe4d678aac
9 changed files with 582 additions and 98 deletions

View File

@ -36,6 +36,7 @@ import io.druid.sql.calcite.filtration.Filtration;
import io.druid.sql.calcite.table.RowSignature; import io.druid.sql.calcite.table.RowSignature;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
@ -112,7 +113,7 @@ public class Aggregation
public static Aggregation create(final PostAggregator postAggregator) public static Aggregation create(final PostAggregator postAggregator)
{ {
return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator); return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
} }
public static Aggregation create( public static Aggregation create(

View File

@ -89,6 +89,7 @@ import javax.annotation.Nullable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.OptionalInt;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -105,9 +106,11 @@ public class DruidQuery
private final DimFilter filter; private final DimFilter filter;
private final SelectProjection selectProjection; private final SelectProjection selectProjection;
private final Grouping grouping; private final Grouping grouping;
private final SortProject sortProject;
private final DefaultLimitSpec limitSpec;
private final RowSignature outputRowSignature; private final RowSignature outputRowSignature;
private final RelDataType outputRowType; private final RelDataType outputRowType;
private final DefaultLimitSpec limitSpec;
private final Query query; private final Query query;
public DruidQuery( public DruidQuery(
@ -128,15 +131,22 @@ public class DruidQuery
this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature); this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder); this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder);
final RowSignature sortingInputRowSignature;
if (this.selectProjection != null) { if (this.selectProjection != null) {
this.outputRowSignature = this.selectProjection.getOutputRowSignature(); sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
} else if (this.grouping != null) { } else if (this.grouping != null) {
this.outputRowSignature = this.grouping.getOutputRowSignature(); sortingInputRowSignature = this.grouping.getOutputRowSignature();
} else { } else {
this.outputRowSignature = sourceRowSignature; sortingInputRowSignature = sourceRowSignature;
} }
this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature); this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);
// outputRowSignature is used only for scan and select query, and thus sort and grouping must be null
this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();
this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
this.query = computeQuery(); this.query = computeQuery();
} }
@ -235,7 +245,7 @@ public class DruidQuery
) )
{ {
final Aggregate aggregate = partialQuery.getAggregate(); final Aggregate aggregate = partialQuery.getAggregate();
final Project postProject = partialQuery.getPostProject(); final Project aggregateProject = partialQuery.getAggregateProject();
if (aggregate == null) { if (aggregate == null) {
return null; return null;
@ -265,25 +275,108 @@ public class DruidQuery
plannerContext plannerContext
); );
if (postProject == null) { if (aggregateProject == null) {
return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature); return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
} else { } else {
final List<String> rowOrder = new ArrayList<>(); final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
aggregateRowSignature,
aggregateProject,
0
);
projectRowOrderAndPostAggregations.postAggregations.forEach(
postAggregator -> aggregations.add(Aggregation.create(postAggregator))
);
int outputNameCounter = 0; // Remove literal dimensions that did not appear in the projection. This is useful for queries
for (final RexNode postAggregatorRexNode : postProject.getChildExps()) { // like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
// actually want to include a dimension 'dummy'.
final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
.isLiteral() && !aggregateProjectBits.get(i)) {
dimensions.remove(i);
}
}
return Grouping.create(
dimensions,
aggregations,
havingFilter,
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
);
}
}
@Nullable
private SortProject computeSortProject(
PartialDruidQuery partialQuery,
PlannerContext plannerContext,
RowSignature sortingInputRowSignature,
Grouping grouping
)
{
final Project sortProject = partialQuery.getSortProject();
if (sortProject == null) {
return null;
} else {
final List<PostAggregator> postAggregators = grouping.getPostAggregators();
final OptionalInt maybeMaxCounter = postAggregators
.stream()
.mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
.max();
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
sortingInputRowSignature,
sortProject,
maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
);
return new SortProject(
sortingInputRowSignature,
projectRowOrderAndPostAggregations.postAggregations,
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
);
}
}
private static class ProjectRowOrderAndPostAggregations
{
private final List<String> rowOrder;
private final List<PostAggregator> postAggregations;
ProjectRowOrderAndPostAggregations(List<String> rowOrder, List<PostAggregator> postAggregations)
{
this.rowOrder = rowOrder;
this.postAggregations = postAggregations;
}
}
private static ProjectRowOrderAndPostAggregations computePostAggregations(
PlannerContext plannerContext,
RowSignature inputRowSignature,
Project project,
int outputNameCounter
)
{
final List<String> rowOrder = new ArrayList<>();
final List<PostAggregator> aggregations = new ArrayList<>();
for (final RexNode postAggregatorRexNode : project.getChildExps()) {
// Attempt to convert to PostAggregator. // Attempt to convert to PostAggregator.
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression( final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
plannerContext, plannerContext,
aggregateRowSignature, inputRowSignature,
postAggregatorRexNode postAggregatorRexNode
); );
if (postAggregatorExpression == null) { if (postAggregatorExpression == null) {
throw new CannotBuildQueryException(postProject, postAggregatorRexNode); throw new CannotBuildQueryException(project, postAggregatorRexNode);
} }
if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) { if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
// Direct column access, without any type cast as far as Druid's runtime is concerned. // Direct column access, without any type cast as far as Druid's runtime is concerned.
// (There might be a SQL-level type cast that we don't care about) // (There might be a SQL-level type cast that we don't care about)
rowOrder.add(postAggregatorExpression.getDirectColumn()); rowOrder.add(postAggregatorExpression.getDirectColumn());
@ -295,30 +388,12 @@ public class DruidQuery
null, null,
plannerContext.getExprMacroTable() plannerContext.getExprMacroTable()
); );
aggregations.add(Aggregation.create(postAggregator)); aggregations.add(postAggregator);
rowOrder.add(postAggregator.getName()); rowOrder.add(postAggregator.getName());
} }
} }
// Remove literal dimensions that did not appear in the projection. This is useful for queries return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
// actually want to include a dimension 'dummy'.
final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
.isLiteral() && !postProjectBits.get(i)) {
dimensions.remove(i);
}
}
return Grouping.create(
dimensions,
aggregations,
havingFilter,
RowSignature.from(rowOrder, postProject.getRowType())
);
}
} }
/** /**
@ -540,6 +615,9 @@ public class DruidQuery
{ {
final List<VirtualColumn> retVal = new ArrayList<>(); final List<VirtualColumn> retVal = new ArrayList<>();
if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
} else {
if (grouping != null) { if (grouping != null) {
if (includeDimensions) { if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) { for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
@ -550,8 +628,7 @@ public class DruidQuery
for (Aggregation aggregation : grouping.getAggregations()) { for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns()); retVal.addAll(aggregation.getVirtualColumns());
} }
} else if (selectProjection != null) { }
retVal.addAll(selectProjection.getVirtualColumns());
} }
return VirtualColumns.create(retVal); return VirtualColumns.create(retVal);
@ -567,6 +644,11 @@ public class DruidQuery
return limitSpec; return limitSpec;
} }
public SortProject getSortProject()
{
return sortProject;
}
public RelDataType getOutputRowType() public RelDataType getOutputRowType()
{ {
return outputRowType; return outputRowType;
@ -667,7 +749,6 @@ public class DruidQuery
if (limitSpec != null) { if (limitSpec != null) {
// If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing). // If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).
if (limitSpec.isLimited()) { if (limitSpec.isLimited()) {
return null; return null;
} }
@ -797,6 +878,11 @@ public class DruidQuery
final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature); final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);
final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
if (sortProject != null) {
postAggregators.addAll(sortProject.getPostAggregators());
}
return new GroupByQuery( return new GroupByQuery(
dataSource, dataSource,
filtration.getQuerySegmentSpec(), filtration.getQuerySegmentSpec(),
@ -805,7 +891,7 @@ public class DruidQuery
Granularities.ALL, Granularities.ALL,
grouping.getDimensionSpecs(), grouping.getDimensionSpecs(),
grouping.getAggregatorFactories(), grouping.getAggregatorFactories(),
grouping.getPostAggregators(), postAggregators,
grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null, grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
limitSpec, limitSpec,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext()) ImmutableSortedMap.copyOf(plannerContext.getQueryContext())

View File

@ -220,14 +220,18 @@ public class DruidQueryRel extends DruidRel<DruidQueryRel>
cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size(); cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
} }
if (partialQuery.getPostProject() != null) { if (partialQuery.getAggregateProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size(); cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
} }
if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) { if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
cost *= COST_LIMIT_MULTIPLIER; cost *= COST_LIMIT_MULTIPLIER;
} }
if (partialQuery.getSortProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
}
if (partialQuery.getHavingFilter() != null) { if (partialQuery.getHavingFilter() != null) {
cost *= COST_HAVING_MULTIPLIER; cost *= COST_HAVING_MULTIPLIER;
} }

View File

@ -358,8 +358,12 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter()); newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
} }
if (leftPartialQuery.getPostProject() != null) { if (leftPartialQuery.getAggregateProject() != null) {
newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject()); newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
}
if (leftPartialQuery.getSortProject() != null) {
newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
} }
if (leftPartialQuery.getSort() != null) { if (leftPartialQuery.getSort() != null) {

View File

@ -46,8 +46,9 @@ public class PartialDruidQuery
private final Sort selectSort; private final Sort selectSort;
private final Aggregate aggregate; private final Aggregate aggregate;
private final Filter havingFilter; private final Filter havingFilter;
private final Project postProject; private final Project aggregateProject;
private final Sort sort; private final Sort sort;
private final Project sortProject;
public enum Stage public enum Stage
{ {
@ -57,8 +58,9 @@ public class PartialDruidQuery
SELECT_SORT, SELECT_SORT,
AGGREGATE, AGGREGATE,
HAVING_FILTER, HAVING_FILTER,
POST_PROJECT, AGGREGATE_PROJECT,
SORT SORT,
SORT_PROJECT
} }
public PartialDruidQuery( public PartialDruidQuery(
@ -67,9 +69,10 @@ public class PartialDruidQuery
final Project selectProject, final Project selectProject,
final Sort selectSort, final Sort selectSort,
final Aggregate aggregate, final Aggregate aggregate,
final Project postProject, final Project aggregateProject,
final Filter havingFilter, final Filter havingFilter,
final Sort sort final Sort sort,
final Project sortProject
) )
{ {
this.scan = Preconditions.checkNotNull(scan, "scan"); this.scan = Preconditions.checkNotNull(scan, "scan");
@ -77,14 +80,15 @@ public class PartialDruidQuery
this.selectProject = selectProject; this.selectProject = selectProject;
this.selectSort = selectSort; this.selectSort = selectSort;
this.aggregate = aggregate; this.aggregate = aggregate;
this.postProject = postProject; this.aggregateProject = aggregateProject;
this.havingFilter = havingFilter; this.havingFilter = havingFilter;
this.sort = sort; this.sort = sort;
this.sortProject = sortProject;
} }
public static PartialDruidQuery create(final RelNode scanRel) public static PartialDruidQuery create(final RelNode scanRel)
{ {
return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null); return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null, null);
} }
public RelNode getScan() public RelNode getScan()
@ -117,9 +121,9 @@ public class PartialDruidQuery
return havingFilter; return havingFilter;
} }
public Project getPostProject() public Project getAggregateProject()
{ {
return postProject; return aggregateProject;
} }
public Sort getSort() public Sort getSort()
@ -127,6 +131,11 @@ public class PartialDruidQuery
return sort; return sort;
} }
public Project getSortProject()
{
return sortProject;
}
public PartialDruidQuery withWhereFilter(final Filter newWhereFilter) public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
{ {
validateStage(Stage.WHERE_FILTER); validateStage(Stage.WHERE_FILTER);
@ -136,9 +145,10 @@ public class PartialDruidQuery
selectProject, selectProject,
selectSort, selectSort,
aggregate, aggregate,
postProject, aggregateProject,
havingFilter, havingFilter,
sort sort,
sortProject
); );
} }
@ -151,9 +161,10 @@ public class PartialDruidQuery
newSelectProject, newSelectProject,
selectSort, selectSort,
aggregate, aggregate,
postProject, aggregateProject,
havingFilter, havingFilter,
sort sort,
sortProject
); );
} }
@ -166,9 +177,10 @@ public class PartialDruidQuery
selectProject, selectProject,
newSelectSort, newSelectSort,
aggregate, aggregate,
postProject, aggregateProject,
havingFilter, havingFilter,
sort sort,
sortProject
); );
} }
@ -181,9 +193,10 @@ public class PartialDruidQuery
selectProject, selectProject,
selectSort, selectSort,
newAggregate, newAggregate,
postProject, aggregateProject,
havingFilter, havingFilter,
sort sort,
sortProject
); );
} }
@ -196,24 +209,26 @@ public class PartialDruidQuery
selectProject, selectProject,
selectSort, selectSort,
aggregate, aggregate,
postProject, aggregateProject,
newHavingFilter, newHavingFilter,
sort sort,
sortProject
); );
} }
public PartialDruidQuery withPostProject(final Project newPostProject) public PartialDruidQuery withAggregateProject(final Project newAggregateProject)
{ {
validateStage(Stage.POST_PROJECT); validateStage(Stage.AGGREGATE_PROJECT);
return new PartialDruidQuery( return new PartialDruidQuery(
scan, scan,
whereFilter, whereFilter,
selectProject, selectProject,
selectSort, selectSort,
aggregate, aggregate,
newPostProject, newAggregateProject,
havingFilter, havingFilter,
sort sort,
sortProject
); );
} }
@ -226,9 +241,26 @@ public class PartialDruidQuery
selectProject, selectProject,
selectSort, selectSort,
aggregate, aggregate,
postProject, aggregateProject,
havingFilter, havingFilter,
newSort newSort,
sortProject
);
}
public PartialDruidQuery withSortProject(final Project newSortProject)
{
validateStage(Stage.SORT_PROJECT);
return new PartialDruidQuery(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
aggregateProject,
havingFilter,
sort,
newSortProject
); );
} }
@ -265,6 +297,9 @@ public class PartialDruidQuery
} else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) { } else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
// Cannot do any aggregations after a select + sort. // Cannot do any aggregations after a select + sort.
return false; return false;
} else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
// Cannot add sort project without a sort
return false;
} else { } else {
// Looks good. // Looks good.
return true; return true;
@ -277,12 +312,15 @@ public class PartialDruidQuery
* *
* @return stage * @return stage
*/ */
@SuppressWarnings("VariableNotUsedInsideIf")
public Stage stage() public Stage stage()
{ {
if (sort != null) { if (sortProject != null) {
return Stage.SORT_PROJECT;
} else if (sort != null) {
return Stage.SORT; return Stage.SORT;
} else if (postProject != null) { } else if (aggregateProject != null) {
return Stage.POST_PROJECT; return Stage.AGGREGATE_PROJECT;
} else if (havingFilter != null) { } else if (havingFilter != null) {
return Stage.HAVING_FILTER; return Stage.HAVING_FILTER;
} else if (aggregate != null) { } else if (aggregate != null) {
@ -308,10 +346,12 @@ public class PartialDruidQuery
final Stage currentStage = stage(); final Stage currentStage = stage();
switch (currentStage) { switch (currentStage) {
case SORT_PROJECT:
return sortProject;
case SORT: case SORT:
return sort; return sort;
case POST_PROJECT: case AGGREGATE_PROJECT:
return postProject; return aggregateProject;
case HAVING_FILTER: case HAVING_FILTER:
return havingFilter; return havingFilter;
case AGGREGATE: case AGGREGATE:
@ -352,14 +392,25 @@ public class PartialDruidQuery
Objects.equals(selectSort, that.selectSort) && Objects.equals(selectSort, that.selectSort) &&
Objects.equals(aggregate, that.aggregate) && Objects.equals(aggregate, that.aggregate) &&
Objects.equals(havingFilter, that.havingFilter) && Objects.equals(havingFilter, that.havingFilter) &&
Objects.equals(postProject, that.postProject) && Objects.equals(aggregateProject, that.aggregateProject) &&
Objects.equals(sort, that.sort); Objects.equals(sort, that.sort) &&
Objects.equals(sortProject, that.sortProject);
} }
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(scan, whereFilter, selectProject, selectSort, aggregate, havingFilter, postProject, sort); return Objects.hash(
scan,
whereFilter,
selectProject,
selectSort,
aggregate,
havingFilter,
aggregateProject,
sort,
sortProject
);
} }
@Override @Override
@ -372,8 +423,9 @@ public class PartialDruidQuery
", selectSort=" + selectSort + ", selectSort=" + selectSort +
", aggregate=" + aggregate + ", aggregate=" + aggregate +
", havingFilter=" + havingFilter + ", havingFilter=" + havingFilter +
", postProject=" + postProject + ", aggregateProject=" + aggregateProject +
", sort=" + sort + ", sort=" + sort +
", sortProject=" + sortProject +
'}'; '}';
} }
} }

View File

@ -0,0 +1,112 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.sql.calcite.rel;
import com.google.common.base.Preconditions;
import io.druid.java.util.common.ISE;
import io.druid.query.aggregation.PostAggregator;
import io.druid.sql.calcite.table.RowSignature;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
public class SortProject
{
private final RowSignature inputRowSignature;
private final List<PostAggregator> postAggregators;
private final RowSignature outputRowSignature;
SortProject(
RowSignature inputRowSignature,
List<PostAggregator> postAggregators,
RowSignature outputRowSignature
)
{
this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, "inputRowSignature");
this.postAggregators = Preconditions.checkNotNull(postAggregators, "postAggregators");
this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, "outputRowSignature");
// Verify no collisions.
final Set<String> seen = new HashSet<>();
inputRowSignature.getRowOrder().forEach(field -> {
if (!seen.add(field)) {
throw new ISE("Duplicate field name: %s", field);
}
});
for (PostAggregator postAggregator : postAggregators) {
if (postAggregator == null) {
throw new ISE("aggregation[%s] is not a postAggregator", postAggregator);
}
if (!seen.add(postAggregator.getName())) {
throw new ISE("Duplicate field name: %s", postAggregator.getName());
}
}
// Verify that items in the output signature exist.
outputRowSignature.getRowOrder().forEach(field -> {
if (!seen.contains(field)) {
throw new ISE("Missing field in rowOrder: %s", field);
}
});
}
public List<PostAggregator> getPostAggregators()
{
return postAggregators;
}
public RowSignature getOutputRowSignature()
{
return outputRowSignature;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SortProject sortProject = (SortProject) o;
return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
Objects.equals(postAggregators, sortProject.postAggregators) &&
Objects.equals(outputRowSignature, sortProject.outputRowSignature);
}
@Override
public int hashCode()
{
return Objects.hash(inputRowSignature, postAggregators, outputRowSignature);
}
@Override
public String toString()
{
return "SortProject{" +
"inputRowSignature=" + inputRowSignature +
", postAggregators=" + postAggregators +
", outputRowSignature=" + outputRowSignature +
'}';
}
}

View File

@ -68,8 +68,8 @@ public class DruidRules
), ),
new DruidQueryRule<>( new DruidQueryRule<>(
Project.class, Project.class,
PartialDruidQuery.Stage.POST_PROJECT, PartialDruidQuery.Stage.AGGREGATE_PROJECT,
PartialDruidQuery::withPostProject PartialDruidQuery::withAggregateProject
), ),
new DruidQueryRule<>( new DruidQueryRule<>(
Filter.class, Filter.class,
@ -81,10 +81,16 @@ public class DruidRules
PartialDruidQuery.Stage.SORT, PartialDruidQuery.Stage.SORT,
PartialDruidQuery::withSort PartialDruidQuery::withSort
), ),
new DruidQueryRule<>(
Project.class,
PartialDruidQuery.Stage.SORT_PROJECT,
PartialDruidQuery::withSortProject
),
DruidOuterQueryRule.AGGREGATE, DruidOuterQueryRule.AGGREGATE,
DruidOuterQueryRule.FILTER_AGGREGATE, DruidOuterQueryRule.FILTER_AGGREGATE,
DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE, DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
DruidOuterQueryRule.PROJECT_AGGREGATE DruidOuterQueryRule.PROJECT_AGGREGATE,
DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
); );
} }
@ -227,6 +233,32 @@ public class DruidRules
} }
}; };
public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
operand(Project.class, operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, any())))),
"AGGREGATE_SORT_PROJECT"
)
{
@Override
public void onMatch(RelOptRuleCall call)
{
final Project sortProject = call.rel(0);
final Sort sort = call.rel(1);
final Aggregate aggregate = call.rel(2);
final DruidRel druidRel = call.rel(3);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withAggregate(aggregate)
.withSort(sort)
.withSortProject(sortProject)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public DruidOuterQueryRule(final RelOptRuleOperand op, final String description) public DruidOuterQueryRule(final RelOptRuleOperand op, final String description)
{ {
super(op, StringUtils.format("%s:%s", DruidOuterQueryRel.class.getSimpleName(), description)); super(op, StringUtils.format("%s:%s", DruidOuterQueryRel.class.getSimpleName(), description));

View File

@ -24,6 +24,7 @@ import com.google.common.base.Predicates;
import io.druid.sql.calcite.planner.PlannerConfig; import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.DruidRel; import io.druid.sql.calcite.rel.DruidRel;
import io.druid.sql.calcite.rel.DruidSemiJoin; import io.druid.sql.calcite.rel.DruidSemiJoin;
import io.druid.sql.calcite.rel.PartialDruidQuery;
import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelOptUtil;
@ -115,15 +116,18 @@ public class DruidSemiJoinRule extends RelOptRule
return; return;
} }
final Project rightPostProject = right.getPartialDruidQuery().getPostProject(); final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
final Project rightProject = rightQuery.getSortProject() != null ?
rightQuery.getSortProject() :
rightQuery.getAggregateProject();
int i = 0; int i = 0;
for (int joinRef : joinInfo.rightSet()) { for (int joinRef : joinInfo.rightSet()) {
final int aggregateRef; final int aggregateRef;
if (rightPostProject == null) { if (rightProject == null) {
aggregateRef = joinRef; aggregateRef = joinRef;
} else { } else {
final RexNode projectExp = rightPostProject.getChildExps().get(joinRef); final RexNode projectExp = rightProject.getChildExps().get(joinRef);
if (projectExp.isA(SqlKind.INPUT_REF)) { if (projectExp.isA(SqlKind.INPUT_REF)) {
aggregateRef = ((RexInputRef) projectExp).getIndex(); aggregateRef = ((RexInputRef) projectExp).getIndex();
} else { } else {

View File

@ -70,6 +70,7 @@ import io.druid.query.groupby.GroupByQuery;
import io.druid.query.groupby.having.DimFilterHavingSpec; import io.druid.query.groupby.having.DimFilterHavingSpec;
import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.query.groupby.orderby.DefaultLimitSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec; import io.druid.query.groupby.orderby.OrderByColumnSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
import io.druid.query.lookup.RegisteredLookupExtractionFn; import io.druid.query.lookup.RegisteredLookupExtractionFn;
import io.druid.query.ordering.StringComparator; import io.druid.query.ordering.StringComparator;
import io.druid.query.ordering.StringComparators; import io.druid.query.ordering.StringComparators;
@ -119,6 +120,7 @@ import org.junit.rules.TemporaryFolder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -6529,6 +6531,193 @@ public class CalciteQueryTest extends CalciteTestBase
); );
} }
@Test
public void testProjectAfterSort() throws Exception
{
testQuery(
"select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
DIMS(
new DefaultDimensionSpec("dim1", "d0"),
new DefaultDimensionSpec("dim2", "d1")
)
)
.setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{""},
new Object[]{"1"},
new Object[]{"10.1"},
new Object[]{"2"},
new Object[]{"abc"},
new Object[]{"def"}
)
);
}
@Test
public void testProjectAfterSort2() throws Exception
{
testQuery(
"select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
DIMS(
new DefaultDimensionSpec("dim1", "d0"),
new DefaultDimensionSpec("dim2", "d1")
)
)
.setAggregatorSpecs(
AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2"))
)
.setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")")))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1.0, "", "a", 1.0},
new Object[]{4.0, "1", "a", 4.0},
new Object[]{2.0, "10.1", "", 2.0},
new Object[]{3.0, "2", "", 3.0},
new Object[]{6.0, "abc", "", 6.0},
new Object[]{5.0, "def", "abc", 5.0}
)
);
}
@Test
public void testProjectAfterSort3() throws Exception
{
testQuery(
"select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
DIMS(
new DefaultDimensionSpec("dim1", "d0")
)
)
.setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{""},
new Object[]{"1"},
new Object[]{"10.1"},
new Object[]{"2"},
new Object[]{"abc"},
new Object[]{"def"}
)
);
}
@Test
public void testSortProjectAfterNestedGroupBy() throws Exception
{
testQuery(
"SELECT "
+ " cnt "
+ "FROM ("
+ " SELECT "
+ " __time, "
+ " dim1, "
+ " COUNT(m2) AS cnt "
+ " FROM ("
+ " SELECT "
+ " __time, "
+ " m2, "
+ " dim1 "
+ " FROM druid.foo "
+ " GROUP BY __time, m2, dim1 "
+ " ) "
+ " GROUP BY __time, dim1 "
+ " ORDER BY cnt"
+ ")",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(DIMS(
new DefaultDimensionSpec("__time", "d0", ValueType.LONG),
new DefaultDimensionSpec("dim1", "d1"),
new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
.setInterval(QSS(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(DIMS(
new DefaultDimensionSpec("d0", "_d0", ValueType.LONG),
new DefaultDimensionSpec("d1", "_d1", ValueType.STRING)
))
.setAggregatorSpecs(AGGS(
new CountAggregatorFactory("a0")
))
.setLimitSpec(
new DefaultLimitSpec(
Collections.singletonList(
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
),
Integer.MAX_VALUE
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L},
new Object[]{1L},
new Object[]{1L},
new Object[]{1L},
new Object[]{1L},
new Object[]{1L}
)
);
}
private void testQuery( private void testQuery(
final String sql, final String sql,
final List<Query> expectedQueries, final List<Query> expectedQueries,