Fix APPROX_QUANTILE on outer groupBys. (#5253)

This commit is contained in:
Gian Merlino 2018-01-12 12:01:32 -08:00 committed by Fangjin Yang
parent 491f8cca81
commit a11049c82f
2 changed files with 83 additions and 1 deletions

View File

@ -93,15 +93,27 @@ public class QuantileSqlAggregator implements SqlAggregator
project, project,
aggregateCall.getArgList().get(1) aggregateCall.getArgList().get(1)
); );
final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
if (!probabilityArg.isA(SqlKind.LITERAL)) {
// Probability must be a literal in order to plan.
return null;
}
final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
final int resolution; final int resolution;
if (aggregateCall.getArgList().size() >= 3) { if (aggregateCall.getArgList().size() >= 3) {
final RexNode resolutionArg = Expressions.fromFieldAccess( final RexNode resolutionArg = Expressions.fromFieldAccess(
rowSignature, rowSignature,
project, project,
aggregateCall.getArgList().get(2) aggregateCall.getArgList().get(2)
); );
if (!resolutionArg.isA(SqlKind.LITERAL)) {
// Resolution must be a literal in order to plan.
return null;
}
resolution = ((Number) RexLiteral.value(resolutionArg)).intValue(); resolution = ((Number) RexLiteral.value(resolutionArg)).intValue();
} else { } else {
resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE; resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;

View File

@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import io.druid.java.util.common.granularity.Granularities; import io.druid.java.util.common.granularity.Granularities;
import io.druid.query.Druids; import io.druid.query.Druids;
import io.druid.query.QueryDataSource;
import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleSumAggregatorFactory; import io.druid.query.aggregation.DoubleSumAggregatorFactory;
import io.druid.query.aggregation.FilteredAggregatorFactory; import io.druid.query.aggregation.FilteredAggregatorFactory;
@ -32,9 +33,13 @@ import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactor
import io.druid.query.aggregation.histogram.ApproximateHistogramDruidModule; import io.druid.query.aggregation.histogram.ApproximateHistogramDruidModule;
import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory; import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
import io.druid.query.aggregation.histogram.QuantilePostAggregator; import io.druid.query.aggregation.histogram.QuantilePostAggregator;
import io.druid.query.aggregation.post.ArithmeticPostAggregator;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.dimension.DefaultDimensionSpec;
import io.druid.query.expression.TestExprMacroTable; import io.druid.query.expression.TestExprMacroTable;
import io.druid.query.filter.NotDimFilter; import io.druid.query.filter.NotDimFilter;
import io.druid.query.filter.SelectorDimFilter; import io.druid.query.filter.SelectorDimFilter;
import io.druid.query.groupby.GroupByQuery;
import io.druid.query.spec.MultipleIntervalSegmentSpec; import io.druid.query.spec.MultipleIntervalSegmentSpec;
import io.druid.segment.IndexBuilder; import io.druid.segment.IndexBuilder;
import io.druid.segment.QueryableIndex; import io.druid.segment.QueryableIndex;
@ -291,4 +296,69 @@ public class QuantileSqlAggregatorTest
); );
} }
} }
@Test
public void testQuantileOnInnerQuery() throws Exception
{
try (final DruidPlanner planner = plannerFactory.createPlanner(null)) {
final String sql = "SELECT AVG(x), APPROX_QUANTILE(x, 0.98)\n"
+ "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)";
final PlannerResult plannerResult = planner.plan(sql);
// Verify results
final List<Object[]> results = plannerResult.run().toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{7.0, 8.26386833190918}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
// Verify query
Assert.assertEquals(
GroupByQuery.builder()
.setDataSource(
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setDimensions(ImmutableList.of(new DefaultDimensionSpec("dim2", "d0")))
.setAggregatorSpecs(
ImmutableList.of(
new DoubleSumAggregatorFactory("a0", "m1")
)
)
.setContext(ImmutableMap.of())
.build()
)
)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(ImmutableList.of(
new DoubleSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count"),
new ApproximateHistogramAggregatorFactory("_a1:agg", "a0", null, null, null, null)
))
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
"_a0",
"quotient",
ImmutableList.of(
new FieldAccessPostAggregator(null, "_a0:sum"),
new FieldAccessPostAggregator(null, "_a0:count")
)
),
new QuantilePostAggregator("_a1", "_a1:agg", 0.98f)
)
)
.setContext(ImmutableMap.of())
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
}
} }