diff --git a/benchmarks/src/main/java/io/druid/benchmark/query/SqlBenchmark.java b/benchmarks/src/main/java/io/druid/benchmark/query/SqlBenchmark.java index 4637abc7fbb..440cf6440cd 100644 --- a/benchmarks/src/main/java/io/druid/benchmark/query/SqlBenchmark.java +++ b/benchmarks/src/main/java/io/druid/benchmark/query/SqlBenchmark.java @@ -53,6 +53,7 @@ import io.druid.sql.calcite.planner.PlannerFactory; import io.druid.sql.calcite.planner.PlannerResult; import io.druid.sql.calcite.rel.QueryMaker; import io.druid.sql.calcite.table.DruidTable; +import io.druid.sql.calcite.table.RowSignature; import io.druid.sql.calcite.util.CalciteTests; import io.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; import io.druid.timeline.DataSegment; @@ -158,12 +159,12 @@ public class SqlBenchmark new DruidTable( new QueryMaker(walker, plannerConfig), new TableDataSource("foo"), - ImmutableMap.of( - "__time", ValueType.LONG, - "dimSequential", ValueType.STRING, - "dimZipf", ValueType.STRING, - "dimUniform", ValueType.STRING - ) + RowSignature.builder() + .add("__time", ValueType.LONG) + .add("dimSequential", ValueType.STRING) + .add("dimZipf", ValueType.STRING) + .add("dimUniform", ValueType.STRING) + .build() ) ); final Schema druidSchema = new AbstractSchema() diff --git a/docs/content/querying/sql.md b/docs/content/querying/sql.md index b9b94ecf332..b0fb6eeafb8 100644 --- a/docs/content/querying/sql.md +++ b/docs/content/querying/sql.md @@ -149,8 +149,10 @@ Some Druid extensions also include SQL language extensions. If the [approximate histogram extension](../development/extensions-core/approximate-histograms.html) is loaded: -- `QUANTILE(column, probability)` on numeric or approximate histogram columns computes approximate quantiles. The -"probability" should be between 0 and 1 (exclusive). +- `APPROX_QUANTILE(column, probability)` or `APPROX_QUANTILE(column, probability, resolution)` on numeric or +approximate histogram columns computes approximate quantiles. The "probability" should be between 0 and 1 (exclusive). +The "resolution" is the number of centroids to use for the computation. Higher resolutions will be give more +precise results but also have higher overhead. If not provided, the default resolution is 50. ### Unsupported features diff --git a/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java index fe3896f3b77..aaf3745812c 100644 --- a/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java @@ -19,14 +19,17 @@ package io.druid.query.aggregation.histogram.sql; +import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.histogram.ApproximateHistogram; import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactory; import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory; import io.druid.query.aggregation.histogram.QuantilePostAggregator; +import io.druid.query.filter.DimFilter; import io.druid.segment.column.ValueType; import io.druid.sql.calcite.aggregation.Aggregation; +import io.druid.sql.calcite.aggregation.Aggregations; import io.druid.sql.calcite.aggregation.SqlAggregator; import io.druid.sql.calcite.expression.Expressions; import io.druid.sql.calcite.expression.RowExtraction; @@ -48,7 +51,7 @@ import java.util.List; public class QuantileSqlAggregator implements SqlAggregator { private static final SqlAggFunction FUNCTION_INSTANCE = new QuantileSqlAggFunction(); - private static final String NAME = "QUANTILE"; + private static final String NAME = "APPROX_QUANTILE"; @Override public SqlAggFunction calciteFunction() @@ -62,7 +65,8 @@ public class QuantileSqlAggregator implements SqlAggregator final RowSignature rowSignature, final List existingAggregations, final Project project, - final AggregateCall aggregateCall + final AggregateCall aggregateCall, + final DimFilter filter ) { final RowExtraction rex = Expressions.toRowExtraction( @@ -77,6 +81,8 @@ public class QuantileSqlAggregator implements SqlAggregator return null; } + final AggregatorFactory aggregatorFactory; + final String histogramName = String.format("%s:agg", name); final RexNode probabilityArg = Expressions.fromFieldAccess( rowSignature, project, @@ -84,10 +90,18 @@ public class QuantileSqlAggregator implements SqlAggregator ); final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue(); - final AggregatorFactory aggregatorFactory; - final String histogramName = String.format("%s:agg", name); + final int resolution; + if (aggregateCall.getArgList().size() >= 3) { + final RexNode resolutionArg = Expressions.fromFieldAccess( + rowSignature, + project, + aggregateCall.getArgList().get(2) + ); + resolution = ((Number) RexLiteral.value(resolutionArg)).intValue(); + } else { + resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE; + } - final int resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE; final int numBuckets = ApproximateHistogram.DEFAULT_BUCKET_SIZE; final float lowerLimit = Float.NEGATIVE_INFINITY; final float upperLimit = Float.POSITIVE_INFINITY; @@ -95,19 +109,30 @@ public class QuantileSqlAggregator implements SqlAggregator // Look for existing matching aggregatorFactory. for (final Aggregation existing : existingAggregations) { for (AggregatorFactory factory : existing.getAggregatorFactories()) { - if (factory instanceof ApproximateHistogramAggregatorFactory) { - final ApproximateHistogramAggregatorFactory theFactory = (ApproximateHistogramAggregatorFactory) factory; - if (theFactory.getFieldName().equals(rex.getColumn()) - && theFactory.getResolution() == resolution - && theFactory.getNumBuckets() == numBuckets - && theFactory.getLowerLimit() == lowerLimit - && theFactory.getUpperLimit() == upperLimit) { - // Found existing one. Use this. - return Aggregation.create( - ImmutableList.of(), - new QuantilePostAggregator(name, theFactory.getName(), probability) - ); - } + final boolean matches = Aggregations.aggregatorMatches( + factory, + filter, + ApproximateHistogramAggregatorFactory.class, + new Predicate() + { + @Override + public boolean apply(final ApproximateHistogramAggregatorFactory theFactory) + { + return theFactory.getFieldName().equals(rex.getColumn()) + && theFactory.getResolution() == resolution + && theFactory.getNumBuckets() == numBuckets + && theFactory.getLowerLimit() == lowerLimit + && theFactory.getUpperLimit() == upperLimit; + } + } + ); + + if (matches) { + // Found existing one. Use this. + return Aggregation.create( + ImmutableList.of(), + new QuantilePostAggregator(name, factory.getName(), probability) + ); } } } @@ -135,12 +160,13 @@ public class QuantileSqlAggregator implements SqlAggregator return Aggregation.create( ImmutableList.of(aggregatorFactory), new QuantilePostAggregator(name, histogramName, probability) - ); + ).filter(filter); } private static class QuantileSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE = "'" + NAME + "(column, probability)'"; + private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'\n"; + private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'\n"; QuantileSqlAggFunction() { @@ -150,9 +176,15 @@ public class QuantileSqlAggregator implements SqlAggregator SqlKind.OTHER_FUNCTION, ReturnTypes.explicit(SqlTypeName.DOUBLE), null, - OperandTypes.and( - OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL), - OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + OperandTypes.or( + OperandTypes.and( + OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + ), + OperandTypes.and( + OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC) + ) ), SqlFunctionCategory.NUMERIC, false, diff --git a/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java index 1e92d5e9d41..a597a48f385 100644 --- a/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/io/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java @@ -26,10 +26,22 @@ import io.druid.granularity.QueryGranularities; import io.druid.java.util.common.guava.Sequences; import io.druid.query.Druids; import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.CountAggregatorFactory; +import io.druid.query.aggregation.DoubleSumAggregatorFactory; +import io.druid.query.aggregation.FilteredAggregatorFactory; import io.druid.query.aggregation.PostAggregator; import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactory; +import io.druid.query.aggregation.histogram.ApproximateHistogramDruidModule; +import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory; import io.druid.query.aggregation.histogram.QuantilePostAggregator; +import io.druid.query.filter.NotDimFilter; +import io.druid.query.filter.SelectorDimFilter; import io.druid.query.spec.MultipleIntervalSegmentSpec; +import io.druid.segment.IndexBuilder; +import io.druid.segment.QueryableIndex; +import io.druid.segment.TestHelper; +import io.druid.segment.incremental.IncrementalIndexSchema; +import io.druid.sql.calcite.CalciteQueryTest; import io.druid.sql.calcite.aggregation.SqlAggregator; import io.druid.sql.calcite.filtration.Filtration; import io.druid.sql.calcite.planner.Calcites; @@ -40,6 +52,8 @@ import io.druid.sql.calcite.planner.PlannerResult; import io.druid.sql.calcite.util.CalciteTests; import io.druid.sql.calcite.util.QueryLogHook; import io.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; +import io.druid.timeline.DataSegment; +import io.druid.timeline.partition.LinearShardSpec; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.tools.Planner; import org.junit.After; @@ -52,10 +66,10 @@ import org.junit.rules.TemporaryFolder; import java.util.ArrayList; import java.util.List; -import static io.druid.sql.calcite.CalciteQueryTest.TIMESERIES_CONTEXT; - public class QuantileSqlAggregatorTest { + private static final String DATA_SOURCE = "foo"; + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); @@ -69,7 +83,45 @@ public class QuantileSqlAggregatorTest public void setUp() throws Exception { Calcites.setSystemProperties(); - walker = CalciteTests.createMockWalker(temporaryFolder.newFolder()); + + // Note: this is needed in order to properly register the serde for Histogram. + new ApproximateHistogramDruidModule().configure(null); + + final QueryableIndex index = IndexBuilder.create() + .tmpDir(temporaryFolder.newFolder()) + .indexMerger(TestHelper.getTestIndexMergerV9()) + .schema( + new IncrementalIndexSchema.Builder() + .withMetrics( + new AggregatorFactory[]{ + new CountAggregatorFactory("cnt"), + new DoubleSumAggregatorFactory("m1", "m1"), + new ApproximateHistogramAggregatorFactory( + "hist_m1", + "m1", + null, + null, + null, + null + ) + } + ) + .withRollup(false) + .build() + ) + .rows(CalciteTests.ROWS1) + .buildMMappedIndex(); + + walker = new SpecificSegmentsQuerySegmentWalker(CalciteTests.queryRunnerFactoryConglomerate()).add( + DataSegment.builder() + .dataSource(DATA_SOURCE) + .interval(index.getDataInterval()) + .version("1") + .shardSpec(new LinearShardSpec(0)) + .build(), + index + ); + final PlannerConfig plannerConfig = new PlannerConfig(); final SchemaPlus rootSchema = Calcites.createRootSchema( CalciteTests.createMockSchema( @@ -96,13 +148,23 @@ public class QuantileSqlAggregatorTest public void testQuantileOnFloatAndLongs() throws Exception { try (final Planner planner = plannerFactory.createPlanner()) { - final String sql = "SELECT QUANTILE(m1, 0.01), QUANTILE(m1, 0.5), QUANTILE(m1, 0.99), QUANTILE(cnt, 0.5) FROM foo"; + final String sql = "SELECT\n" + + "APPROX_QUANTILE(m1, 0.01),\n" + + "APPROX_QUANTILE(m1, 0.5, 50),\n" + + "APPROX_QUANTILE(m1, 0.98, 200),\n" + + "APPROX_QUANTILE(m1, 0.99),\n" + + "APPROX_QUANTILE(m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n" + + "APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE(cnt, 0.5)\n" + + "FROM foo"; + final PlannerResult plannerResult = Calcites.plan(planner, sql); // Verify results final List results = Sequences.toList(plannerResult.run(), new ArrayList()); final List expectedResults = ImmutableList.of( - new Object[]{1.0, 3.0, 5.940000057220459, 1.0} + new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0, 1.0} ); Assert.assertEquals(expectedResults.size(), results.size()); for (int i = 0; i < expectedResults.size(); i++) { @@ -115,17 +177,90 @@ public class QuantileSqlAggregatorTest .dataSource(CalciteTests.DATASOURCE1) .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) .granularity(QueryGranularities.ALL) - .aggregators(ImmutableList.of( + .aggregators(ImmutableList.of( new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null), - new ApproximateHistogramAggregatorFactory("a3:agg", "cnt", null, null, null, null) + new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null), + new FilteredAggregatorFactory( + new ApproximateHistogramAggregatorFactory("a4:agg", "m1", null, null, null, null), + new SelectorDimFilter("dim1", "abc", null) + ), + new FilteredAggregatorFactory( + new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null), + new NotDimFilter(new SelectorDimFilter("dim1", "abc", null)) + ), + new ApproximateHistogramAggregatorFactory("a7:agg", "cnt", null, null, null, null) )) .postAggregators(ImmutableList.of( new QuantilePostAggregator("a0", "a0:agg", 0.01f), new QuantilePostAggregator("a1", "a0:agg", 0.50f), - new QuantilePostAggregator("a2", "a0:agg", 0.99f), - new QuantilePostAggregator("a3", "a3:agg", 0.50f) + new QuantilePostAggregator("a2", "a2:agg", 0.98f), + new QuantilePostAggregator("a3", "a0:agg", 0.99f), + new QuantilePostAggregator("a4", "a4:agg", 0.99f), + new QuantilePostAggregator("a5", "a5:agg", 0.999f), + new QuantilePostAggregator("a6", "a4:agg", 0.999f), + new QuantilePostAggregator("a7", "a7:agg", 0.50f) )) - .context(TIMESERIES_CONTEXT) + .context(CalciteQueryTest.TIMESERIES_CONTEXT) + .build(), + Iterables.getOnlyElement(queryLogHook.getRecordedQueries()) + ); + } + } + + @Test + public void testQuantileOnComplexColumn() throws Exception + { + try (final Planner planner = plannerFactory.createPlanner()) { + final String sql = "SELECT\n" + + "APPROX_QUANTILE(hist_m1, 0.01),\n" + + "APPROX_QUANTILE(hist_m1, 0.5, 50),\n" + + "APPROX_QUANTILE(hist_m1, 0.98, 200),\n" + + "APPROX_QUANTILE(hist_m1, 0.99),\n" + + "APPROX_QUANTILE(hist_m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n" + + "APPROX_QUANTILE(hist_m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n" + + "APPROX_QUANTILE(hist_m1, 0.999) FILTER(WHERE dim1 = 'abc')\n" + + "FROM foo"; + + final PlannerResult plannerResult = Calcites.plan(planner, sql); + + // Verify results + final List results = Sequences.toList(plannerResult.run(), new ArrayList()); + final List expectedResults = ImmutableList.of( + new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0} + ); + Assert.assertEquals(expectedResults.size(), results.size()); + for (int i = 0; i < expectedResults.size(); i++) { + Assert.assertArrayEquals(expectedResults.get(i), results.get(i)); + } + + // Verify query + Assert.assertEquals( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity()))) + .granularity(QueryGranularities.ALL) + .aggregators(ImmutableList.of( + new ApproximateHistogramFoldingAggregatorFactory("a0:agg", "hist_m1", null, null, null, null), + new ApproximateHistogramFoldingAggregatorFactory("a2:agg", "hist_m1", 200, null, null, null), + new FilteredAggregatorFactory( + new ApproximateHistogramFoldingAggregatorFactory("a4:agg", "hist_m1", null, null, null, null), + new SelectorDimFilter("dim1", "abc", null) + ), + new FilteredAggregatorFactory( + new ApproximateHistogramFoldingAggregatorFactory("a5:agg", "hist_m1", null, null, null, null), + new NotDimFilter(new SelectorDimFilter("dim1", "abc", null)) + ) + )) + .postAggregators(ImmutableList.of( + new QuantilePostAggregator("a0", "a0:agg", 0.01f), + new QuantilePostAggregator("a1", "a0:agg", 0.50f), + new QuantilePostAggregator("a2", "a2:agg", 0.98f), + new QuantilePostAggregator("a3", "a0:agg", 0.99f), + new QuantilePostAggregator("a4", "a4:agg", 0.99f), + new QuantilePostAggregator("a5", "a5:agg", 0.999f), + new QuantilePostAggregator("a6", "a4:agg", 0.999f) + )) + .context(CalciteQueryTest.TIMESERIES_CONTEXT) .build(), Iterables.getOnlyElement(queryLogHook.getRecordedQueries()) ); diff --git a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregations.java b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregations.java new file mode 100644 index 00000000000..b833892bc68 --- /dev/null +++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregations.java @@ -0,0 +1,60 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.sql.calcite.aggregation; + +import com.google.common.base.Predicate; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.FilteredAggregatorFactory; +import io.druid.query.filter.DimFilter; + +public class Aggregations +{ + private Aggregations() + { + // No instantiation. + } + + /** + * Returns true if "factory" is an aggregator factory that either matches "predicate" (if filter is null) or is + * a filtered aggregator factory whose filter is equal to "filter" and underlying aggregator matches "predicate". + * + * @param factory factory to match + * @param filter filter, may be null + * @param clazz class of factory to match + * @param predicate predicate + * + * @return true if the aggregator matches filter + predicate + */ + public static boolean aggregatorMatches( + final AggregatorFactory factory, + final DimFilter filter, + final Class clazz, + final Predicate predicate + ) + { + if (filter != null) { + return factory instanceof FilteredAggregatorFactory && + ((FilteredAggregatorFactory) factory).getFilter().equals(filter) + && aggregatorMatches(((FilteredAggregatorFactory) factory).getAggregator(), null, clazz, predicate); + } else { + return clazz.isAssignableFrom(factory.getClass()) && predicate.apply(clazz.cast(factory)); + } + } +} diff --git a/sql/src/main/java/io/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java b/sql/src/main/java/io/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java index 36ed5d8cf62..beed16ae9a0 100644 --- a/sql/src/main/java/io/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java @@ -27,6 +27,7 @@ import io.druid.query.aggregation.cardinality.CardinalityAggregatorFactory; import io.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator; import io.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory; import io.druid.query.dimension.DimensionSpec; +import io.druid.query.filter.DimFilter; import io.druid.segment.column.ValueType; import io.druid.sql.calcite.expression.Expressions; import io.druid.sql.calcite.expression.RowExtraction; @@ -60,7 +61,8 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator final RowSignature rowSignature, final List existingAggregations, final Project project, - final AggregateCall aggregateCall + final AggregateCall aggregateCall, + final DimFilter filter ) { final RowExtraction rex = Expressions.toRowExtraction( @@ -99,7 +101,7 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator return new HyperUniqueFinalizingPostAggregator(outputName, name); } } - ); + ).filter(filter); } private static class ApproxCountDistinctSqlAggFunction extends SqlAggFunction diff --git a/sql/src/main/java/io/druid/sql/calcite/aggregation/SqlAggregator.java b/sql/src/main/java/io/druid/sql/calcite/aggregation/SqlAggregator.java index f8e75f1cbfb..0db32845c05 100644 --- a/sql/src/main/java/io/druid/sql/calcite/aggregation/SqlAggregator.java +++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/SqlAggregator.java @@ -19,6 +19,7 @@ package io.druid.sql.calcite.aggregation; +import io.druid.query.filter.DimFilter; import io.druid.sql.calcite.table.RowSignature; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; @@ -44,9 +45,11 @@ public interface SqlAggregator * * @param name desired output name of the aggregation * @param rowSignature signature of the rows being aggregated - * @param existingAggregations existing aggregations for this query; useful for re-using aggregators - * @param project SQL projection to apply before the aggregate call + * @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely + * ignored if you do not want to re-use existing aggregations. + * @param project SQL projection to apply before the aggregate call, may be null * @param aggregateCall SQL aggregate call + * @param filter filter that should be applied to the aggregation, may be null * * @return aggregation, or null if the call cannot be translated */ @@ -56,6 +59,7 @@ public interface SqlAggregator final RowSignature rowSignature, final List existingAggregations, final Project project, - final AggregateCall aggregateCall + final AggregateCall aggregateCall, + final DimFilter filter ); } diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java index 875ffb10c43..14fd569eec5 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java @@ -739,7 +739,6 @@ public class GroupByRules final String name = aggOutputName(aggNumber); final SqlKind kind = call.getAggregation().getKind(); final SqlTypeName outputType = call.getType().getSqlTypeName(); - final Aggregation retVal; if (call.filterArg >= 0) { // AGG(xxx) FILTER(WHERE yyy) @@ -759,15 +758,16 @@ public class GroupByRules if (kind == SqlKind.COUNT && call.getArgList().isEmpty()) { // COUNT(*) - retVal = Aggregation.create(new CountAggregatorFactory(name)); + return Aggregation.create(new CountAggregatorFactory(name)).filter(makeFilter(filters, sourceRowSignature)); } else if (kind == SqlKind.COUNT && call.isDistinct()) { // COUNT(DISTINCT x) - retVal = approximateCountDistinct ? APPROX_COUNT_DISTINCT.toDruidAggregation( + return approximateCountDistinct ? APPROX_COUNT_DISTINCT.toDruidAggregation( name, sourceRowSignature, existingAggregations, project, - call + call, + makeFilter(filters, sourceRowSignature) ) : null; } else if (kind == SqlKind.COUNT || kind == SqlKind.SUM @@ -845,9 +845,10 @@ public class GroupByRules if (forceCount || kind == SqlKind.COUNT) { // COUNT(x) - retVal = Aggregation.create(new CountAggregatorFactory(name)); + return Aggregation.create(new CountAggregatorFactory(name)).filter(makeFilter(filters, sourceRowSignature)); } else { // Built-in aggregator that is not COUNT. + final Aggregation retVal; final String fieldName = input.getFieldName(); final String expression = input.getExpression(); @@ -888,26 +889,21 @@ public class GroupByRules // Not reached. throw new ISE("WTF?! Kind[%s] got into the built-in aggregator path somehow?!", kind); } + + return retVal.filter(makeFilter(filters, sourceRowSignature)); } } else { // Not a built-in aggregator, check operator table. final SqlAggregator sqlAggregator = operatorTable.lookupAggregator(call.getAggregation().getName()); - retVal = sqlAggregator != null ? sqlAggregator.toDruidAggregation( + return sqlAggregator != null ? sqlAggregator.toDruidAggregation( name, sourceRowSignature, existingAggregations, project, - call + call, + makeFilter(filters, sourceRowSignature) ) : null; } - - final DimFilter filter = filters.isEmpty() - ? null - : Filtration.create(new AndDimFilter(filters)) - .optimizeFilterOnly(sourceRowSignature) - .getDimFilter(); - - return retVal != null ? retVal.filter(filter) : null; } public static String dimOutputName(final int dimNumber) @@ -924,4 +920,13 @@ public class GroupByRules { return "A" + aggNumber + ":" + key; } + + private static DimFilter makeFilter(final List filters, final RowSignature sourceRowSignature) + { + return filters.isEmpty() + ? null + : Filtration.create(new AndDimFilter(filters)) + .optimizeFilterOnly(sourceRowSignature) + .getDimFilter(); + } } diff --git a/sql/src/main/java/io/druid/sql/calcite/schema/DruidSchema.java b/sql/src/main/java/io/druid/sql/calcite/schema/DruidSchema.java index d975746dfa8..b4f73a1d366 100644 --- a/sql/src/main/java/io/druid/sql/calcite/schema/DruidSchema.java +++ b/sql/src/main/java/io/druid/sql/calcite/schema/DruidSchema.java @@ -50,6 +50,7 @@ import io.druid.server.coordination.DruidServerMetadata; import io.druid.sql.calcite.planner.PlannerConfig; import io.druid.sql.calcite.rel.QueryMaker; import io.druid.sql.calcite.table.DruidTable; +import io.druid.sql.calcite.table.RowSignature; import io.druid.timeline.DataSegment; import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractSchema; @@ -296,7 +297,7 @@ public class DruidSchema extends AbstractSchema } final Map columnMetadata = Iterables.getOnlyElement(results).getColumns(); - final Map columnValueTypes = Maps.newHashMap(); + final RowSignature.Builder rowSignature = RowSignature.builder(); for (Map.Entry entry : columnMetadata.entrySet()) { if (entry.getValue().isError()) { @@ -314,13 +315,13 @@ public class DruidSchema extends AbstractSchema valueType = ValueType.COMPLEX; } - columnValueTypes.put(entry.getKey(), valueType); + rowSignature.add(entry.getKey(), valueType); } return new DruidTable( queryMaker, new TableDataSource(dataSource), - columnValueTypes + rowSignature.build() ); } } diff --git a/sql/src/main/java/io/druid/sql/calcite/table/DruidTable.java b/sql/src/main/java/io/druid/sql/calcite/table/DruidTable.java index 34f2c8b82c0..759afb35de6 100644 --- a/sql/src/main/java/io/druid/sql/calcite/table/DruidTable.java +++ b/sql/src/main/java/io/druid/sql/calcite/table/DruidTable.java @@ -20,9 +20,7 @@ package io.druid.sql.calcite.table; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableSortedMap; import io.druid.query.DataSource; -import io.druid.segment.column.ValueType; import io.druid.sql.calcite.rel.DruidQueryRel; import io.druid.sql.calcite.rel.QueryMaker; import org.apache.calcite.plan.RelOptCluster; @@ -35,8 +33,6 @@ import org.apache.calcite.schema.Statistic; import org.apache.calcite.schema.Statistics; import org.apache.calcite.schema.TranslatableTable; -import java.util.Map; - public class DruidTable implements TranslatableTable { private final QueryMaker queryMaker; @@ -46,17 +42,12 @@ public class DruidTable implements TranslatableTable public DruidTable( final QueryMaker queryMaker, final DataSource dataSource, - final Map columns + final RowSignature rowSignature ) { this.queryMaker = Preconditions.checkNotNull(queryMaker, "queryMaker"); this.dataSource = Preconditions.checkNotNull(dataSource, "dataSource"); - - final RowSignature.Builder rowSignatureBuilder = RowSignature.builder(); - for (Map.Entry entry : ImmutableSortedMap.copyOf(columns).entrySet()) { - rowSignatureBuilder.add(entry.getKey(), entry.getValue()); - } - this.rowSignature = rowSignatureBuilder.build(); + this.rowSignature = Preconditions.checkNotNull(rowSignature, "rowSignature"); } public QueryMaker getQueryMaker() diff --git a/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java b/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java index 68e2859146a..aada90bbae3 100644 --- a/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java +++ b/sql/src/main/java/io/druid/sql/calcite/table/RowSignature.java @@ -51,7 +51,7 @@ public class RowSignature private final Map columnTypes; private final List columnNames; - public RowSignature(final List> columnTypeList) + private RowSignature(final List> columnTypeList) { final Map columnTypes0 = Maps.newHashMap(); final ImmutableList.Builder columnNamesBuilder = ImmutableList.builder(); diff --git a/sql/src/test/java/io/druid/sql/calcite/util/CalciteTests.java b/sql/src/test/java/io/druid/sql/calcite/util/CalciteTests.java index 7d27fe3cfe1..089b65980c5 100644 --- a/sql/src/test/java/io/druid/sql/calcite/util/CalciteTests.java +++ b/sql/src/test/java/io/druid/sql/calcite/util/CalciteTests.java @@ -20,6 +20,7 @@ package io.druid.sql.calcite.util; import com.google.common.base.Supplier; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -36,8 +37,6 @@ import io.druid.query.Query; import io.druid.query.QueryRunnerFactory; import io.druid.query.QueryRunnerFactoryConglomerate; import io.druid.query.QueryRunnerTestHelper; -import io.druid.query.QuerySegmentWalker; -import io.druid.query.TableDataSource; import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.aggregation.DoubleSumAggregatorFactory; @@ -65,19 +64,14 @@ import io.druid.query.topn.TopNQueryRunnerFactory; import io.druid.segment.IndexBuilder; import io.druid.segment.QueryableIndex; import io.druid.segment.TestHelper; -import io.druid.segment.column.ValueType; import io.druid.segment.incremental.IncrementalIndexSchema; import io.druid.sql.calcite.aggregation.ApproxCountDistinctSqlAggregator; import io.druid.sql.calcite.aggregation.SqlAggregator; import io.druid.sql.calcite.planner.DruidOperatorTable; import io.druid.sql.calcite.planner.PlannerConfig; -import io.druid.sql.calcite.rel.QueryMaker; -import io.druid.sql.calcite.table.DruidTable; +import io.druid.sql.calcite.schema.DruidSchema; import io.druid.timeline.DataSegment; import io.druid.timeline.partition.LinearShardSpec; -import org.apache.calcite.schema.Schema; -import org.apache.calcite.schema.Table; -import org.apache.calcite.schema.impl.AbstractSchema; import org.joda.time.DateTime; import java.io.File; @@ -208,7 +202,7 @@ public class CalciteTests .withRollup(false) .build(); - private static final List ROWS1 = ImmutableList.of( + public static final List ROWS1 = ImmutableList.of( createRow(ImmutableMap.of("t", "2000-01-01", "m1", "1.0", "dim1", "", "dim2", ImmutableList.of("a"))), createRow(ImmutableMap.of("t", "2000-01-02", "m1", "2.0", "dim1", "10.1", "dim2", ImmutableList.of())), createRow(ImmutableMap.of("t", "2000-01-03", "m1", "3.0", "dim1", "2", "dim2", ImmutableList.of(""))), @@ -217,21 +211,12 @@ public class CalciteTests createRow(ImmutableMap.of("t", "2001-01-03", "m1", "6.0", "dim1", "abc")) ); - private static final List ROWS2 = ImmutableList.of( + public static final List ROWS2 = ImmutableList.of( createRow("2000-01-01", "דרואיד", "he", 1.0), createRow("2000-01-01", "druid", "en", 1.0), createRow("2000-01-01", "друид", "ru", 1.0) ); - private static final Map COLUMN_TYPES = ImmutableMap.builder() - .put("__time", ValueType.LONG) - .put("cnt", ValueType.LONG) - .put("dim1", ValueType.STRING) - .put("dim2", ValueType.STRING) - .put("m1", ValueType.FLOAT) - .put("unique_dim1", ValueType.COMPLEX) - .build(); - private CalciteTests() { // No instantiation. @@ -282,23 +267,27 @@ public class CalciteTests return new DruidOperatorTable(ImmutableSet.of(new ApproxCountDistinctSqlAggregator())); } - public static Schema createMockSchema(final QuerySegmentWalker walker, final PlannerConfig plannerConfig) + public static DruidSchema createMockSchema( + final SpecificSegmentsQuerySegmentWalker walker, + final PlannerConfig plannerConfig + ) { - final QueryMaker queryMaker = new QueryMaker(walker, plannerConfig); - final DruidTable druidTable1 = new DruidTable(queryMaker, new TableDataSource(DATASOURCE1), COLUMN_TYPES); - final DruidTable druidTable2 = new DruidTable(queryMaker, new TableDataSource(DATASOURCE2), COLUMN_TYPES); - final Map tableMap = ImmutableMap.of( - DATASOURCE1, druidTable1, - DATASOURCE2, druidTable2 + final DruidSchema schema = new DruidSchema( + walker, + new TestServerInventoryView(walker.getSegments()), + plannerConfig ); - return new AbstractSchema() - { - @Override - protected Map getTableMap() - { - return tableMap; - } - }; + + schema.start(); + try { + schema.awaitInitialization(); + } + catch (InterruptedException e) { + throw Throwables.propagate(e); + } + + schema.stop(); + return schema; } public static InputRow createRow(final ImmutableMap map)