mirror of https://github.com/apache/druid.git
Support projection after sorting in SQL (#5788)
* Add sort project * add more test * address comments
This commit is contained in:
parent
e43e5ebbcd
commit
fe4d678aac
|
@ -36,6 +36,7 @@ import io.druid.sql.calcite.filtration.Filtration;
|
|||
import io.druid.sql.calcite.table.RowSignature;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
@ -112,7 +113,7 @@ public class Aggregation
|
|||
|
||||
public static Aggregation create(final PostAggregator postAggregator)
|
||||
{
|
||||
return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator);
|
||||
return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
|
||||
}
|
||||
|
||||
public static Aggregation create(
|
||||
|
|
|
@ -89,6 +89,7 @@ import javax.annotation.Nullable;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.TreeSet;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -105,9 +106,11 @@ public class DruidQuery
|
|||
private final DimFilter filter;
|
||||
private final SelectProjection selectProjection;
|
||||
private final Grouping grouping;
|
||||
private final SortProject sortProject;
|
||||
private final DefaultLimitSpec limitSpec;
|
||||
private final RowSignature outputRowSignature;
|
||||
private final RelDataType outputRowType;
|
||||
private final DefaultLimitSpec limitSpec;
|
||||
|
||||
private final Query query;
|
||||
|
||||
public DruidQuery(
|
||||
|
@ -128,15 +131,22 @@ public class DruidQuery
|
|||
this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
|
||||
this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder);
|
||||
|
||||
final RowSignature sortingInputRowSignature;
|
||||
|
||||
if (this.selectProjection != null) {
|
||||
this.outputRowSignature = this.selectProjection.getOutputRowSignature();
|
||||
sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
|
||||
} else if (this.grouping != null) {
|
||||
this.outputRowSignature = this.grouping.getOutputRowSignature();
|
||||
sortingInputRowSignature = this.grouping.getOutputRowSignature();
|
||||
} else {
|
||||
this.outputRowSignature = sourceRowSignature;
|
||||
sortingInputRowSignature = sourceRowSignature;
|
||||
}
|
||||
|
||||
this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
|
||||
this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);
|
||||
|
||||
// outputRowSignature is used only for scan and select query, and thus sort and grouping must be null
|
||||
this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();
|
||||
|
||||
this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
|
||||
this.query = computeQuery();
|
||||
}
|
||||
|
||||
|
@ -235,7 +245,7 @@ public class DruidQuery
|
|||
)
|
||||
{
|
||||
final Aggregate aggregate = partialQuery.getAggregate();
|
||||
final Project postProject = partialQuery.getPostProject();
|
||||
final Project aggregateProject = partialQuery.getAggregateProject();
|
||||
|
||||
if (aggregate == null) {
|
||||
return null;
|
||||
|
@ -265,49 +275,27 @@ public class DruidQuery
|
|||
plannerContext
|
||||
);
|
||||
|
||||
if (postProject == null) {
|
||||
if (aggregateProject == null) {
|
||||
return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
|
||||
} else {
|
||||
final List<String> rowOrder = new ArrayList<>();
|
||||
|
||||
int outputNameCounter = 0;
|
||||
for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
|
||||
// Attempt to convert to PostAggregator.
|
||||
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
|
||||
plannerContext,
|
||||
aggregateRowSignature,
|
||||
postAggregatorRexNode
|
||||
);
|
||||
|
||||
if (postAggregatorExpression == null) {
|
||||
throw new CannotBuildQueryException(postProject, postAggregatorRexNode);
|
||||
}
|
||||
|
||||
if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
|
||||
// Direct column access, without any type cast as far as Druid's runtime is concerned.
|
||||
// (There might be a SQL-level type cast that we don't care about)
|
||||
rowOrder.add(postAggregatorExpression.getDirectColumn());
|
||||
} else {
|
||||
final String postAggregatorName = "p" + outputNameCounter++;
|
||||
final PostAggregator postAggregator = new ExpressionPostAggregator(
|
||||
postAggregatorName,
|
||||
postAggregatorExpression.getExpression(),
|
||||
null,
|
||||
plannerContext.getExprMacroTable()
|
||||
);
|
||||
aggregations.add(Aggregation.create(postAggregator));
|
||||
rowOrder.add(postAggregator.getName());
|
||||
}
|
||||
}
|
||||
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
|
||||
plannerContext,
|
||||
aggregateRowSignature,
|
||||
aggregateProject,
|
||||
0
|
||||
);
|
||||
projectRowOrderAndPostAggregations.postAggregations.forEach(
|
||||
postAggregator -> aggregations.add(Aggregation.create(postAggregator))
|
||||
);
|
||||
|
||||
// Remove literal dimensions that did not appear in the projection. This is useful for queries
|
||||
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
|
||||
// actually want to include a dimension 'dummy'.
|
||||
final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
|
||||
final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
|
||||
for (int i = dimensions.size() - 1; i >= 0; i--) {
|
||||
final DimensionExpression dimension = dimensions.get(i);
|
||||
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
|
||||
.isLiteral() && !postProjectBits.get(i)) {
|
||||
.isLiteral() && !aggregateProjectBits.get(i)) {
|
||||
dimensions.remove(i);
|
||||
}
|
||||
}
|
||||
|
@ -316,11 +304,98 @@ public class DruidQuery
|
|||
dimensions,
|
||||
aggregations,
|
||||
havingFilter,
|
||||
RowSignature.from(rowOrder, postProject.getRowType())
|
||||
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private SortProject computeSortProject(
|
||||
PartialDruidQuery partialQuery,
|
||||
PlannerContext plannerContext,
|
||||
RowSignature sortingInputRowSignature,
|
||||
Grouping grouping
|
||||
)
|
||||
{
|
||||
final Project sortProject = partialQuery.getSortProject();
|
||||
if (sortProject == null) {
|
||||
return null;
|
||||
} else {
|
||||
final List<PostAggregator> postAggregators = grouping.getPostAggregators();
|
||||
final OptionalInt maybeMaxCounter = postAggregators
|
||||
.stream()
|
||||
.mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
|
||||
.max();
|
||||
|
||||
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
|
||||
plannerContext,
|
||||
sortingInputRowSignature,
|
||||
sortProject,
|
||||
maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
|
||||
);
|
||||
|
||||
return new SortProject(
|
||||
sortingInputRowSignature,
|
||||
projectRowOrderAndPostAggregations.postAggregations,
|
||||
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private static class ProjectRowOrderAndPostAggregations
|
||||
{
|
||||
private final List<String> rowOrder;
|
||||
private final List<PostAggregator> postAggregations;
|
||||
|
||||
ProjectRowOrderAndPostAggregations(List<String> rowOrder, List<PostAggregator> postAggregations)
|
||||
{
|
||||
this.rowOrder = rowOrder;
|
||||
this.postAggregations = postAggregations;
|
||||
}
|
||||
}
|
||||
|
||||
private static ProjectRowOrderAndPostAggregations computePostAggregations(
|
||||
PlannerContext plannerContext,
|
||||
RowSignature inputRowSignature,
|
||||
Project project,
|
||||
int outputNameCounter
|
||||
)
|
||||
{
|
||||
final List<String> rowOrder = new ArrayList<>();
|
||||
final List<PostAggregator> aggregations = new ArrayList<>();
|
||||
|
||||
for (final RexNode postAggregatorRexNode : project.getChildExps()) {
|
||||
// Attempt to convert to PostAggregator.
|
||||
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
|
||||
plannerContext,
|
||||
inputRowSignature,
|
||||
postAggregatorRexNode
|
||||
);
|
||||
|
||||
if (postAggregatorExpression == null) {
|
||||
throw new CannotBuildQueryException(project, postAggregatorRexNode);
|
||||
}
|
||||
|
||||
if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
|
||||
// Direct column access, without any type cast as far as Druid's runtime is concerned.
|
||||
// (There might be a SQL-level type cast that we don't care about)
|
||||
rowOrder.add(postAggregatorExpression.getDirectColumn());
|
||||
} else {
|
||||
final String postAggregatorName = "p" + outputNameCounter++;
|
||||
final PostAggregator postAggregator = new ExpressionPostAggregator(
|
||||
postAggregatorName,
|
||||
postAggregatorExpression.getExpression(),
|
||||
null,
|
||||
plannerContext.getExprMacroTable()
|
||||
);
|
||||
aggregations.add(postAggregator);
|
||||
rowOrder.add(postAggregator.getName());
|
||||
}
|
||||
}
|
||||
|
||||
return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order.
|
||||
*
|
||||
|
@ -540,18 +615,20 @@ public class DruidQuery
|
|||
{
|
||||
final List<VirtualColumn> retVal = new ArrayList<>();
|
||||
|
||||
if (grouping != null) {
|
||||
if (includeDimensions) {
|
||||
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
|
||||
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
|
||||
if (selectProjection != null) {
|
||||
retVal.addAll(selectProjection.getVirtualColumns());
|
||||
} else {
|
||||
if (grouping != null) {
|
||||
if (includeDimensions) {
|
||||
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
|
||||
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
|
||||
}
|
||||
}
|
||||
|
||||
for (Aggregation aggregation : grouping.getAggregations()) {
|
||||
retVal.addAll(aggregation.getVirtualColumns());
|
||||
}
|
||||
}
|
||||
|
||||
for (Aggregation aggregation : grouping.getAggregations()) {
|
||||
retVal.addAll(aggregation.getVirtualColumns());
|
||||
}
|
||||
} else if (selectProjection != null) {
|
||||
retVal.addAll(selectProjection.getVirtualColumns());
|
||||
}
|
||||
|
||||
return VirtualColumns.create(retVal);
|
||||
|
@ -567,6 +644,11 @@ public class DruidQuery
|
|||
return limitSpec;
|
||||
}
|
||||
|
||||
public SortProject getSortProject()
|
||||
{
|
||||
return sortProject;
|
||||
}
|
||||
|
||||
public RelDataType getOutputRowType()
|
||||
{
|
||||
return outputRowType;
|
||||
|
@ -667,7 +749,6 @@ public class DruidQuery
|
|||
|
||||
if (limitSpec != null) {
|
||||
// If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).
|
||||
|
||||
if (limitSpec.isLimited()) {
|
||||
return null;
|
||||
}
|
||||
|
@ -797,6 +878,11 @@ public class DruidQuery
|
|||
|
||||
final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);
|
||||
|
||||
final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
|
||||
if (sortProject != null) {
|
||||
postAggregators.addAll(sortProject.getPostAggregators());
|
||||
}
|
||||
|
||||
return new GroupByQuery(
|
||||
dataSource,
|
||||
filtration.getQuerySegmentSpec(),
|
||||
|
@ -805,7 +891,7 @@ public class DruidQuery
|
|||
Granularities.ALL,
|
||||
grouping.getDimensionSpecs(),
|
||||
grouping.getAggregatorFactories(),
|
||||
grouping.getPostAggregators(),
|
||||
postAggregators,
|
||||
grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
|
||||
limitSpec,
|
||||
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
|
||||
|
|
|
@ -220,14 +220,18 @@ public class DruidQueryRel extends DruidRel<DruidQueryRel>
|
|||
cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
|
||||
}
|
||||
|
||||
if (partialQuery.getPostProject() != null) {
|
||||
cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size();
|
||||
if (partialQuery.getAggregateProject() != null) {
|
||||
cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
|
||||
}
|
||||
|
||||
if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
|
||||
cost *= COST_LIMIT_MULTIPLIER;
|
||||
}
|
||||
|
||||
if (partialQuery.getSortProject() != null) {
|
||||
cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
|
||||
}
|
||||
|
||||
if (partialQuery.getHavingFilter() != null) {
|
||||
cost *= COST_HAVING_MULTIPLIER;
|
||||
}
|
||||
|
|
|
@ -358,8 +358,12 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
|
|||
newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
|
||||
}
|
||||
|
||||
if (leftPartialQuery.getPostProject() != null) {
|
||||
newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
|
||||
if (leftPartialQuery.getAggregateProject() != null) {
|
||||
newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
|
||||
}
|
||||
|
||||
if (leftPartialQuery.getSortProject() != null) {
|
||||
newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
|
||||
}
|
||||
|
||||
if (leftPartialQuery.getSort() != null) {
|
||||
|
|
|
@ -46,8 +46,9 @@ public class PartialDruidQuery
|
|||
private final Sort selectSort;
|
||||
private final Aggregate aggregate;
|
||||
private final Filter havingFilter;
|
||||
private final Project postProject;
|
||||
private final Project aggregateProject;
|
||||
private final Sort sort;
|
||||
private final Project sortProject;
|
||||
|
||||
public enum Stage
|
||||
{
|
||||
|
@ -57,8 +58,9 @@ public class PartialDruidQuery
|
|||
SELECT_SORT,
|
||||
AGGREGATE,
|
||||
HAVING_FILTER,
|
||||
POST_PROJECT,
|
||||
SORT
|
||||
AGGREGATE_PROJECT,
|
||||
SORT,
|
||||
SORT_PROJECT
|
||||
}
|
||||
|
||||
public PartialDruidQuery(
|
||||
|
@ -67,9 +69,10 @@ public class PartialDruidQuery
|
|||
final Project selectProject,
|
||||
final Sort selectSort,
|
||||
final Aggregate aggregate,
|
||||
final Project postProject,
|
||||
final Project aggregateProject,
|
||||
final Filter havingFilter,
|
||||
final Sort sort
|
||||
final Sort sort,
|
||||
final Project sortProject
|
||||
)
|
||||
{
|
||||
this.scan = Preconditions.checkNotNull(scan, "scan");
|
||||
|
@ -77,14 +80,15 @@ public class PartialDruidQuery
|
|||
this.selectProject = selectProject;
|
||||
this.selectSort = selectSort;
|
||||
this.aggregate = aggregate;
|
||||
this.postProject = postProject;
|
||||
this.aggregateProject = aggregateProject;
|
||||
this.havingFilter = havingFilter;
|
||||
this.sort = sort;
|
||||
this.sortProject = sortProject;
|
||||
}
|
||||
|
||||
public static PartialDruidQuery create(final RelNode scanRel)
|
||||
{
|
||||
return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null);
|
||||
return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
public RelNode getScan()
|
||||
|
@ -117,9 +121,9 @@ public class PartialDruidQuery
|
|||
return havingFilter;
|
||||
}
|
||||
|
||||
public Project getPostProject()
|
||||
public Project getAggregateProject()
|
||||
{
|
||||
return postProject;
|
||||
return aggregateProject;
|
||||
}
|
||||
|
||||
public Sort getSort()
|
||||
|
@ -127,6 +131,11 @@ public class PartialDruidQuery
|
|||
return sort;
|
||||
}
|
||||
|
||||
public Project getSortProject()
|
||||
{
|
||||
return sortProject;
|
||||
}
|
||||
|
||||
public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
|
||||
{
|
||||
validateStage(Stage.WHERE_FILTER);
|
||||
|
@ -136,9 +145,10 @@ public class PartialDruidQuery
|
|||
selectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
postProject,
|
||||
aggregateProject,
|
||||
havingFilter,
|
||||
sort
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -151,9 +161,10 @@ public class PartialDruidQuery
|
|||
newSelectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
postProject,
|
||||
aggregateProject,
|
||||
havingFilter,
|
||||
sort
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -166,9 +177,10 @@ public class PartialDruidQuery
|
|||
selectProject,
|
||||
newSelectSort,
|
||||
aggregate,
|
||||
postProject,
|
||||
aggregateProject,
|
||||
havingFilter,
|
||||
sort
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -181,9 +193,10 @@ public class PartialDruidQuery
|
|||
selectProject,
|
||||
selectSort,
|
||||
newAggregate,
|
||||
postProject,
|
||||
aggregateProject,
|
||||
havingFilter,
|
||||
sort
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -196,24 +209,26 @@ public class PartialDruidQuery
|
|||
selectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
postProject,
|
||||
aggregateProject,
|
||||
newHavingFilter,
|
||||
sort
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
public PartialDruidQuery withPostProject(final Project newPostProject)
|
||||
public PartialDruidQuery withAggregateProject(final Project newAggregateProject)
|
||||
{
|
||||
validateStage(Stage.POST_PROJECT);
|
||||
validateStage(Stage.AGGREGATE_PROJECT);
|
||||
return new PartialDruidQuery(
|
||||
scan,
|
||||
whereFilter,
|
||||
selectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
newPostProject,
|
||||
newAggregateProject,
|
||||
havingFilter,
|
||||
sort
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -226,9 +241,26 @@ public class PartialDruidQuery
|
|||
selectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
postProject,
|
||||
aggregateProject,
|
||||
havingFilter,
|
||||
newSort
|
||||
newSort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
public PartialDruidQuery withSortProject(final Project newSortProject)
|
||||
{
|
||||
validateStage(Stage.SORT_PROJECT);
|
||||
return new PartialDruidQuery(
|
||||
scan,
|
||||
whereFilter,
|
||||
selectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
aggregateProject,
|
||||
havingFilter,
|
||||
sort,
|
||||
newSortProject
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -265,6 +297,9 @@ public class PartialDruidQuery
|
|||
} else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) {
|
||||
// Cannot do any aggregations after a select + sort.
|
||||
return false;
|
||||
} else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
|
||||
// Cannot add sort project without a sort
|
||||
return false;
|
||||
} else {
|
||||
// Looks good.
|
||||
return true;
|
||||
|
@ -277,12 +312,15 @@ public class PartialDruidQuery
|
|||
*
|
||||
* @return stage
|
||||
*/
|
||||
@SuppressWarnings("VariableNotUsedInsideIf")
|
||||
public Stage stage()
|
||||
{
|
||||
if (sort != null) {
|
||||
if (sortProject != null) {
|
||||
return Stage.SORT_PROJECT;
|
||||
} else if (sort != null) {
|
||||
return Stage.SORT;
|
||||
} else if (postProject != null) {
|
||||
return Stage.POST_PROJECT;
|
||||
} else if (aggregateProject != null) {
|
||||
return Stage.AGGREGATE_PROJECT;
|
||||
} else if (havingFilter != null) {
|
||||
return Stage.HAVING_FILTER;
|
||||
} else if (aggregate != null) {
|
||||
|
@ -308,10 +346,12 @@ public class PartialDruidQuery
|
|||
final Stage currentStage = stage();
|
||||
|
||||
switch (currentStage) {
|
||||
case SORT_PROJECT:
|
||||
return sortProject;
|
||||
case SORT:
|
||||
return sort;
|
||||
case POST_PROJECT:
|
||||
return postProject;
|
||||
case AGGREGATE_PROJECT:
|
||||
return aggregateProject;
|
||||
case HAVING_FILTER:
|
||||
return havingFilter;
|
||||
case AGGREGATE:
|
||||
|
@ -352,14 +392,25 @@ public class PartialDruidQuery
|
|||
Objects.equals(selectSort, that.selectSort) &&
|
||||
Objects.equals(aggregate, that.aggregate) &&
|
||||
Objects.equals(havingFilter, that.havingFilter) &&
|
||||
Objects.equals(postProject, that.postProject) &&
|
||||
Objects.equals(sort, that.sort);
|
||||
Objects.equals(aggregateProject, that.aggregateProject) &&
|
||||
Objects.equals(sort, that.sort) &&
|
||||
Objects.equals(sortProject, that.sortProject);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode()
|
||||
{
|
||||
return Objects.hash(scan, whereFilter, selectProject, selectSort, aggregate, havingFilter, postProject, sort);
|
||||
return Objects.hash(
|
||||
scan,
|
||||
whereFilter,
|
||||
selectProject,
|
||||
selectSort,
|
||||
aggregate,
|
||||
havingFilter,
|
||||
aggregateProject,
|
||||
sort,
|
||||
sortProject
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -372,8 +423,9 @@ public class PartialDruidQuery
|
|||
", selectSort=" + selectSort +
|
||||
", aggregate=" + aggregate +
|
||||
", havingFilter=" + havingFilter +
|
||||
", postProject=" + postProject +
|
||||
", aggregateProject=" + aggregateProject +
|
||||
", sort=" + sort +
|
||||
", sortProject=" + sortProject +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. Metamarkets licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package io.druid.sql.calcite.rel;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.druid.java.util.common.ISE;
|
||||
import io.druid.query.aggregation.PostAggregator;
|
||||
import io.druid.sql.calcite.table.RowSignature;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
public class SortProject
|
||||
{
|
||||
private final RowSignature inputRowSignature;
|
||||
private final List<PostAggregator> postAggregators;
|
||||
private final RowSignature outputRowSignature;
|
||||
|
||||
SortProject(
|
||||
RowSignature inputRowSignature,
|
||||
List<PostAggregator> postAggregators,
|
||||
RowSignature outputRowSignature
|
||||
)
|
||||
{
|
||||
this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, "inputRowSignature");
|
||||
this.postAggregators = Preconditions.checkNotNull(postAggregators, "postAggregators");
|
||||
this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, "outputRowSignature");
|
||||
|
||||
// Verify no collisions.
|
||||
final Set<String> seen = new HashSet<>();
|
||||
inputRowSignature.getRowOrder().forEach(field -> {
|
||||
if (!seen.add(field)) {
|
||||
throw new ISE("Duplicate field name: %s", field);
|
||||
}
|
||||
});
|
||||
|
||||
for (PostAggregator postAggregator : postAggregators) {
|
||||
if (postAggregator == null) {
|
||||
throw new ISE("aggregation[%s] is not a postAggregator", postAggregator);
|
||||
}
|
||||
if (!seen.add(postAggregator.getName())) {
|
||||
throw new ISE("Duplicate field name: %s", postAggregator.getName());
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that items in the output signature exist.
|
||||
outputRowSignature.getRowOrder().forEach(field -> {
|
||||
if (!seen.contains(field)) {
|
||||
throw new ISE("Missing field in rowOrder: %s", field);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public List<PostAggregator> getPostAggregators()
|
||||
{
|
||||
return postAggregators;
|
||||
}
|
||||
|
||||
public RowSignature getOutputRowSignature()
|
||||
{
|
||||
return outputRowSignature;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o)
|
||||
{
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
SortProject sortProject = (SortProject) o;
|
||||
return Objects.equals(inputRowSignature, sortProject.inputRowSignature) &&
|
||||
Objects.equals(postAggregators, sortProject.postAggregators) &&
|
||||
Objects.equals(outputRowSignature, sortProject.outputRowSignature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode()
|
||||
{
|
||||
return Objects.hash(inputRowSignature, postAggregators, outputRowSignature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString()
|
||||
{
|
||||
return "SortProject{" +
|
||||
"inputRowSignature=" + inputRowSignature +
|
||||
", postAggregators=" + postAggregators +
|
||||
", outputRowSignature=" + outputRowSignature +
|
||||
'}';
|
||||
}
|
||||
}
|
|
@ -68,8 +68,8 @@ public class DruidRules
|
|||
),
|
||||
new DruidQueryRule<>(
|
||||
Project.class,
|
||||
PartialDruidQuery.Stage.POST_PROJECT,
|
||||
PartialDruidQuery::withPostProject
|
||||
PartialDruidQuery.Stage.AGGREGATE_PROJECT,
|
||||
PartialDruidQuery::withAggregateProject
|
||||
),
|
||||
new DruidQueryRule<>(
|
||||
Filter.class,
|
||||
|
@ -81,10 +81,16 @@ public class DruidRules
|
|||
PartialDruidQuery.Stage.SORT,
|
||||
PartialDruidQuery::withSort
|
||||
),
|
||||
new DruidQueryRule<>(
|
||||
Project.class,
|
||||
PartialDruidQuery.Stage.SORT_PROJECT,
|
||||
PartialDruidQuery::withSortProject
|
||||
),
|
||||
DruidOuterQueryRule.AGGREGATE,
|
||||
DruidOuterQueryRule.FILTER_AGGREGATE,
|
||||
DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE,
|
||||
DruidOuterQueryRule.PROJECT_AGGREGATE
|
||||
DruidOuterQueryRule.PROJECT_AGGREGATE,
|
||||
DruidOuterQueryRule.AGGREGATE_SORT_PROJECT
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -227,6 +233,32 @@ public class DruidRules
|
|||
}
|
||||
};
|
||||
|
||||
public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
|
||||
operand(Project.class, operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, any())))),
|
||||
"AGGREGATE_SORT_PROJECT"
|
||||
)
|
||||
{
|
||||
@Override
|
||||
public void onMatch(RelOptRuleCall call)
|
||||
{
|
||||
final Project sortProject = call.rel(0);
|
||||
final Sort sort = call.rel(1);
|
||||
final Aggregate aggregate = call.rel(2);
|
||||
final DruidRel druidRel = call.rel(3);
|
||||
|
||||
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
|
||||
druidRel,
|
||||
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
|
||||
.withAggregate(aggregate)
|
||||
.withSort(sort)
|
||||
.withSortProject(sortProject)
|
||||
);
|
||||
if (outerQueryRel.isValidDruidQuery()) {
|
||||
call.transformTo(outerQueryRel);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
public DruidOuterQueryRule(final RelOptRuleOperand op, final String description)
|
||||
{
|
||||
super(op, StringUtils.format("%s:%s", DruidOuterQueryRel.class.getSimpleName(), description));
|
||||
|
|
|
@ -24,6 +24,7 @@ import com.google.common.base.Predicates;
|
|||
import io.druid.sql.calcite.planner.PlannerConfig;
|
||||
import io.druid.sql.calcite.rel.DruidRel;
|
||||
import io.druid.sql.calcite.rel.DruidSemiJoin;
|
||||
import io.druid.sql.calcite.rel.PartialDruidQuery;
|
||||
import org.apache.calcite.plan.RelOptRule;
|
||||
import org.apache.calcite.plan.RelOptRuleCall;
|
||||
import org.apache.calcite.plan.RelOptUtil;
|
||||
|
@ -115,15 +116,18 @@ public class DruidSemiJoinRule extends RelOptRule
|
|||
return;
|
||||
}
|
||||
|
||||
final Project rightPostProject = right.getPartialDruidQuery().getPostProject();
|
||||
final PartialDruidQuery rightQuery = right.getPartialDruidQuery();
|
||||
final Project rightProject = rightQuery.getSortProject() != null ?
|
||||
rightQuery.getSortProject() :
|
||||
rightQuery.getAggregateProject();
|
||||
int i = 0;
|
||||
for (int joinRef : joinInfo.rightSet()) {
|
||||
final int aggregateRef;
|
||||
|
||||
if (rightPostProject == null) {
|
||||
if (rightProject == null) {
|
||||
aggregateRef = joinRef;
|
||||
} else {
|
||||
final RexNode projectExp = rightPostProject.getChildExps().get(joinRef);
|
||||
final RexNode projectExp = rightProject.getChildExps().get(joinRef);
|
||||
if (projectExp.isA(SqlKind.INPUT_REF)) {
|
||||
aggregateRef = ((RexInputRef) projectExp).getIndex();
|
||||
} else {
|
||||
|
|
|
@ -70,6 +70,7 @@ import io.druid.query.groupby.GroupByQuery;
|
|||
import io.druid.query.groupby.having.DimFilterHavingSpec;
|
||||
import io.druid.query.groupby.orderby.DefaultLimitSpec;
|
||||
import io.druid.query.groupby.orderby.OrderByColumnSpec;
|
||||
import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction;
|
||||
import io.druid.query.lookup.RegisteredLookupExtractionFn;
|
||||
import io.druid.query.ordering.StringComparator;
|
||||
import io.druid.query.ordering.StringComparators;
|
||||
|
@ -119,6 +120,7 @@ import org.junit.rules.TemporaryFolder;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -6529,6 +6531,193 @@ public class CalciteQueryTest extends CalciteTestBase
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProjectAfterSort() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(QSS(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(
|
||||
DIMS(
|
||||
new DefaultDimensionSpec("dim1", "d0"),
|
||||
new DefaultDimensionSpec("dim2", "d1")
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
|
||||
.setLimitSpec(
|
||||
new DefaultLimitSpec(
|
||||
Collections.singletonList(
|
||||
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
|
||||
),
|
||||
Integer.MAX_VALUE
|
||||
)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{""},
|
||||
new Object[]{"1"},
|
||||
new Object[]{"10.1"},
|
||||
new Object[]{"2"},
|
||||
new Object[]{"abc"},
|
||||
new Object[]{"def"}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProjectAfterSort2() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(QSS(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(
|
||||
DIMS(
|
||||
new DefaultDimensionSpec("dim1", "d0"),
|
||||
new DefaultDimensionSpec("dim2", "d1")
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(
|
||||
AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2"))
|
||||
)
|
||||
.setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")")))
|
||||
.setLimitSpec(
|
||||
new DefaultLimitSpec(
|
||||
Collections.singletonList(
|
||||
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
|
||||
),
|
||||
Integer.MAX_VALUE
|
||||
)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{1.0, "", "a", 1.0},
|
||||
new Object[]{4.0, "1", "a", 4.0},
|
||||
new Object[]{2.0, "10.1", "", 2.0},
|
||||
new Object[]{3.0, "2", "", 3.0},
|
||||
new Object[]{6.0, "abc", "", 6.0},
|
||||
new Object[]{5.0, "def", "abc", 5.0}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProjectAfterSort3() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(QSS(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(
|
||||
DIMS(
|
||||
new DefaultDimensionSpec("dim1", "d0")
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0")))
|
||||
.setLimitSpec(
|
||||
new DefaultLimitSpec(
|
||||
Collections.singletonList(
|
||||
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
|
||||
),
|
||||
Integer.MAX_VALUE
|
||||
)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{""},
|
||||
new Object[]{"1"},
|
||||
new Object[]{"10.1"},
|
||||
new Object[]{"2"},
|
||||
new Object[]{"abc"},
|
||||
new Object[]{"def"}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSortProjectAfterNestedGroupBy() throws Exception
|
||||
{
|
||||
testQuery(
|
||||
"SELECT "
|
||||
+ " cnt "
|
||||
+ "FROM ("
|
||||
+ " SELECT "
|
||||
+ " __time, "
|
||||
+ " dim1, "
|
||||
+ " COUNT(m2) AS cnt "
|
||||
+ " FROM ("
|
||||
+ " SELECT "
|
||||
+ " __time, "
|
||||
+ " m2, "
|
||||
+ " dim1 "
|
||||
+ " FROM druid.foo "
|
||||
+ " GROUP BY __time, m2, dim1 "
|
||||
+ " ) "
|
||||
+ " GROUP BY __time, dim1 "
|
||||
+ " ORDER BY cnt"
|
||||
+ ")",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(QSS(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(DIMS(
|
||||
new DefaultDimensionSpec("__time", "d0", ValueType.LONG),
|
||||
new DefaultDimensionSpec("dim1", "d1"),
|
||||
new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE)
|
||||
))
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
)
|
||||
.setInterval(QSS(Filtration.eternity()))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(DIMS(
|
||||
new DefaultDimensionSpec("d0", "_d0", ValueType.LONG),
|
||||
new DefaultDimensionSpec("d1", "_d1", ValueType.STRING)
|
||||
))
|
||||
.setAggregatorSpecs(AGGS(
|
||||
new CountAggregatorFactory("a0")
|
||||
))
|
||||
.setLimitSpec(
|
||||
new DefaultLimitSpec(
|
||||
Collections.singletonList(
|
||||
new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC)
|
||||
),
|
||||
Integer.MAX_VALUE
|
||||
)
|
||||
)
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{1L},
|
||||
new Object[]{1L},
|
||||
new Object[]{1L},
|
||||
new Object[]{1L},
|
||||
new Object[]{1L},
|
||||
new Object[]{1L}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private void testQuery(
|
||||
final String sql,
|
||||
final List<Query> expectedQueries,
|
||||
|
|
Loading…
Reference in New Issue