SQL: Add resolution parameter, fix filtering bug with APPROX_QUANTILE (#3868)

* SQL: Add resolution parameter to quantile agg, rename to APPROX_QUANTILE.

* Fix bug with re-use of filtered approximate histogram aggregators.

Also add APPROX_QUANTILE tests for filtering and running on complex columns.
Includes some slight refactoring to allow tests to make DruidTables that
include complex columns.
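(Concretely, the reuse bug could cause two aggregations with different FILTER clauses, such as
APPROX_QUANTILE(m1, 0.99) FILTER(WHERE dim1 = 'abc') and APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 <> 'abc'),
to be collapsed onto a single shared histogram aggregator; the new tests verify that they plan to separate
filtered aggregators.)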

* Remove unused import
Gian Merlino authored 2017-01-25 18:39:26 -08:00; committed by Jonathan Wei
parent 75d9e5e7a7
commit ac84a3e011
12 changed files with 332 additions and 110 deletions

File: SqlBenchmark.java

@@ -53,6 +53,7 @@ import io.druid.sql.calcite.planner.PlannerFactory;
import io.druid.sql.calcite.planner.PlannerResult;
import io.druid.sql.calcite.rel.QueryMaker;
import io.druid.sql.calcite.table.DruidTable;
import io.druid.sql.calcite.table.RowSignature;
import io.druid.sql.calcite.util.CalciteTests;
import io.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import io.druid.timeline.DataSegment;
@@ -158,12 +159,12 @@ public class SqlBenchmark
new DruidTable(
new QueryMaker(walker, plannerConfig),
new TableDataSource("foo"),
ImmutableMap.of(
"__time", ValueType.LONG,
"dimSequential", ValueType.STRING,
"dimZipf", ValueType.STRING,
"dimUniform", ValueType.STRING
)
RowSignature.builder()
.add("__time", ValueType.LONG)
.add("dimSequential", ValueType.STRING)
.add("dimZipf", ValueType.STRING)
.add("dimUniform", ValueType.STRING)
.build()
)
);
final Schema druidSchema = new AbstractSchema()

File: SQL documentation

@@ -149,8 +149,10 @@ Some Druid extensions also include SQL language extensions.
If the [approximate histogram extension](../development/extensions-core/approximate-histograms.html) is loaded:
- `QUANTILE(column, probability)` on numeric or approximate histogram columns computes approximate quantiles. The
"probability" should be between 0 and 1 (exclusive).
- `APPROX_QUANTILE(column, probability)` or `APPROX_QUANTILE(column, probability, resolution)` on numeric or
approximate histogram columns computes approximate quantiles. The "probability" should be between 0 and 1 (exclusive).
The "resolution" is the number of centroids to use for the computation. Higher resolutions will be give more
precise results but also have higher overhead. If not provided, the default resolution is 50.
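For example, `APPROX_QUANTILE(response_ms, 0.95, 200)` computes an approximate 95th percentile of a hypothetical
numeric column "response_ms" using 200 centroids.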
### Unsupported features

File: QuantileSqlAggregator.java

@@ -19,14 +19,17 @@
package io.druid.query.aggregation.histogram.sql;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.histogram.ApproximateHistogram;
import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactory;
import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
import io.druid.query.aggregation.histogram.QuantilePostAggregator;
import io.druid.query.filter.DimFilter;
import io.druid.segment.column.ValueType;
import io.druid.sql.calcite.aggregation.Aggregation;
import io.druid.sql.calcite.aggregation.Aggregations;
import io.druid.sql.calcite.aggregation.SqlAggregator;
import io.druid.sql.calcite.expression.Expressions;
import io.druid.sql.calcite.expression.RowExtraction;
@@ -48,7 +51,7 @@ import java.util.List;
public class QuantileSqlAggregator implements SqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new QuantileSqlAggFunction();
private static final String NAME = "QUANTILE";
private static final String NAME = "APPROX_QUANTILE";
@Override
public SqlAggFunction calciteFunction()
@@ -62,7 +65,8 @@ public class QuantileSqlAggregator implements SqlAggregator
final RowSignature rowSignature,
final List<Aggregation> existingAggregations,
final Project project,
final AggregateCall aggregateCall
final AggregateCall aggregateCall,
final DimFilter filter
)
{
final RowExtraction rex = Expressions.toRowExtraction(
@@ -77,6 +81,8 @@
return null;
}
final AggregatorFactory aggregatorFactory;
final String histogramName = String.format("%s:agg", name);
final RexNode probabilityArg = Expressions.fromFieldAccess(
rowSignature,
project,
@@ -84,10 +90,18 @@
);
final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
final AggregatorFactory aggregatorFactory;
final String histogramName = String.format("%s:agg", name);
final int resolution;
if (aggregateCall.getArgList().size() >= 3) {
final RexNode resolutionArg = Expressions.fromFieldAccess(
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
resolution = ((Number) RexLiteral.value(resolutionArg)).intValue();
} else {
resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
}
final int resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
final int numBuckets = ApproximateHistogram.DEFAULT_BUCKET_SIZE;
final float lowerLimit = Float.NEGATIVE_INFINITY;
final float upperLimit = Float.POSITIVE_INFINITY;
@@ -95,19 +109,30 @@
// Look for existing matching aggregatorFactory.
for (final Aggregation existing : existingAggregations) {
for (AggregatorFactory factory : existing.getAggregatorFactories()) {
if (factory instanceof ApproximateHistogramAggregatorFactory) {
final ApproximateHistogramAggregatorFactory theFactory = (ApproximateHistogramAggregatorFactory) factory;
if (theFactory.getFieldName().equals(rex.getColumn())
&& theFactory.getResolution() == resolution
&& theFactory.getNumBuckets() == numBuckets
&& theFactory.getLowerLimit() == lowerLimit
&& theFactory.getUpperLimit() == upperLimit) {
// Found existing one. Use this.
return Aggregation.create(
ImmutableList.<AggregatorFactory>of(),
new QuantilePostAggregator(name, theFactory.getName(), probability)
);
}
final boolean matches = Aggregations.aggregatorMatches(
factory,
filter,
ApproximateHistogramAggregatorFactory.class,
new Predicate<ApproximateHistogramAggregatorFactory>()
{
@Override
public boolean apply(final ApproximateHistogramAggregatorFactory theFactory)
{
return theFactory.getFieldName().equals(rex.getColumn())
&& theFactory.getResolution() == resolution
&& theFactory.getNumBuckets() == numBuckets
&& theFactory.getLowerLimit() == lowerLimit
&& theFactory.getUpperLimit() == upperLimit;
}
}
);
if (matches) {
// Found existing one. Use this.
return Aggregation.create(
ImmutableList.<AggregatorFactory>of(),
new QuantilePostAggregator(name, factory.getName(), probability)
);
}
}
}
@@ -135,12 +160,13 @@
return Aggregation.create(
ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability)
);
).filter(filter);
}
private static class QuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE = "'" + NAME + "(column, probability)'";
private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'\n";
private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'\n";
QuantileSqlAggFunction()
{
@@ -150,9 +176,15 @@
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
)
),
SqlFunctionCategory.NUMERIC,
false,

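Note that both operand signatures require the probability (and, when present, the resolution) to be literals,
with the resolution constrained to an exact-numeric type: APPROX_QUANTILE(m1, 0.5, 200) validates, while a
column-valued or fractional resolution is rejected during SQL validation.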
File: QuantileSqlAggregatorTest.java

@@ -26,10 +26,22 @@ import io.druid.granularity.QueryGranularities;
import io.druid.java.util.common.guava.Sequences;
import io.druid.query.Druids;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
import io.druid.query.aggregation.FilteredAggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactory;
import io.druid.query.aggregation.histogram.ApproximateHistogramDruidModule;
import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
import io.druid.query.aggregation.histogram.QuantilePostAggregator;
import io.druid.query.filter.NotDimFilter;
import io.druid.query.filter.SelectorDimFilter;
import io.druid.query.spec.MultipleIntervalSegmentSpec;
import io.druid.segment.IndexBuilder;
import io.druid.segment.QueryableIndex;
import io.druid.segment.TestHelper;
import io.druid.segment.incremental.IncrementalIndexSchema;
import io.druid.sql.calcite.CalciteQueryTest;
import io.druid.sql.calcite.aggregation.SqlAggregator;
import io.druid.sql.calcite.filtration.Filtration;
import io.druid.sql.calcite.planner.Calcites;
@@ -40,6 +52,8 @@ import io.druid.sql.calcite.planner.PlannerResult;
import io.druid.sql.calcite.util.CalciteTests;
import io.druid.sql.calcite.util.QueryLogHook;
import io.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import io.druid.timeline.DataSegment;
import io.druid.timeline.partition.LinearShardSpec;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.tools.Planner;
import org.junit.After;
@@ -52,10 +66,10 @@ import org.junit.rules.TemporaryFolder;
import java.util.ArrayList;
import java.util.List;
import static io.druid.sql.calcite.CalciteQueryTest.TIMESERIES_CONTEXT;
public class QuantileSqlAggregatorTest
{
private static final String DATA_SOURCE = "foo";
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@@ -69,7 +83,45 @@ public class QuantileSqlAggregatorTest
public void setUp() throws Exception
{
Calcites.setSystemProperties();
walker = CalciteTests.createMockWalker(temporaryFolder.newFolder());
// Note: this is needed in order to properly register the serde for Histogram.
new ApproximateHistogramDruidModule().configure(null);
final QueryableIndex index = IndexBuilder.create()
.tmpDir(temporaryFolder.newFolder())
.indexMerger(TestHelper.getTestIndexMergerV9())
.schema(
new IncrementalIndexSchema.Builder()
.withMetrics(
new AggregatorFactory[]{
new CountAggregatorFactory("cnt"),
new DoubleSumAggregatorFactory("m1", "m1"),
new ApproximateHistogramAggregatorFactory(
"hist_m1",
"m1",
null,
null,
null,
null
)
}
)
.withRollup(false)
.build()
)
.rows(CalciteTests.ROWS1)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(CalciteTests.queryRunnerFactoryConglomerate()).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
.build(),
index
);
final PlannerConfig plannerConfig = new PlannerConfig();
final SchemaPlus rootSchema = Calcites.createRootSchema(
CalciteTests.createMockSchema(
@@ -96,13 +148,23 @@ public class QuantileSqlAggregatorTest
public void testQuantileOnFloatAndLongs() throws Exception
{
try (final Planner planner = plannerFactory.createPlanner()) {
final String sql = "SELECT QUANTILE(m1, 0.01), QUANTILE(m1, 0.5), QUANTILE(m1, 0.99), QUANTILE(cnt, 0.5) FROM foo";
final String sql = "SELECT\n"
+ "APPROX_QUANTILE(m1, 0.01),\n"
+ "APPROX_QUANTILE(m1, 0.5, 50),\n"
+ "APPROX_QUANTILE(m1, 0.98, 200),\n"
+ "APPROX_QUANTILE(m1, 0.99),\n"
+ "APPROX_QUANTILE(m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n"
+ "APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE(cnt, 0.5)\n"
+ "FROM foo";
final PlannerResult plannerResult = Calcites.plan(planner, sql);
// Verify results
final List<Object[]> results = Sequences.toList(plannerResult.run(), new ArrayList<Object[]>());
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{1.0, 3.0, 5.940000057220459, 1.0}
new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0, 1.0}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
@@ -115,17 +177,90 @@
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(QueryGranularities.ALL)
.aggregators(ImmutableList.<AggregatorFactory>of(
.aggregators(ImmutableList.of(
new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null),
new ApproximateHistogramAggregatorFactory("a3:agg", "cnt", null, null, null, null)
new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null),
new FilteredAggregatorFactory(
new ApproximateHistogramAggregatorFactory("a4:agg", "m1", null, null, null, null),
new SelectorDimFilter("dim1", "abc", null)
),
new FilteredAggregatorFactory(
new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null),
new NotDimFilter(new SelectorDimFilter("dim1", "abc", null))
),
new ApproximateHistogramAggregatorFactory("a7:agg", "cnt", null, null, null, null)
))
.postAggregators(ImmutableList.<PostAggregator>of(
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
new QuantilePostAggregator("a2", "a0:agg", 0.99f),
new QuantilePostAggregator("a3", "a3:agg", 0.50f)
new QuantilePostAggregator("a2", "a2:agg", 0.98f),
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
new QuantilePostAggregator("a4", "a4:agg", 0.99f),
new QuantilePostAggregator("a5", "a5:agg", 0.999f),
new QuantilePostAggregator("a6", "a4:agg", 0.999f),
new QuantilePostAggregator("a7", "a7:agg", 0.50f)
))
.context(TIMESERIES_CONTEXT)
.context(CalciteQueryTest.TIMESERIES_CONTEXT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
}
@Test
public void testQuantileOnComplexColumn() throws Exception
{
try (final Planner planner = plannerFactory.createPlanner()) {
final String sql = "SELECT\n"
+ "APPROX_QUANTILE(hist_m1, 0.01),\n"
+ "APPROX_QUANTILE(hist_m1, 0.5, 50),\n"
+ "APPROX_QUANTILE(hist_m1, 0.98, 200),\n"
+ "APPROX_QUANTILE(hist_m1, 0.99),\n"
+ "APPROX_QUANTILE(hist_m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE(hist_m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n"
+ "APPROX_QUANTILE(hist_m1, 0.999) FILTER(WHERE dim1 = 'abc')\n"
+ "FROM foo";
final PlannerResult plannerResult = Calcites.plan(planner, sql);
// Verify results
final List<Object[]> results = Sequences.toList(plannerResult.run(), new ArrayList<Object[]>());
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
// Verify query
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(QueryGranularities.ALL)
.aggregators(ImmutableList.of(
new ApproximateHistogramFoldingAggregatorFactory("a0:agg", "hist_m1", null, null, null, null),
new ApproximateHistogramFoldingAggregatorFactory("a2:agg", "hist_m1", 200, null, null, null),
new FilteredAggregatorFactory(
new ApproximateHistogramFoldingAggregatorFactory("a4:agg", "hist_m1", null, null, null, null),
new SelectorDimFilter("dim1", "abc", null)
),
new FilteredAggregatorFactory(
new ApproximateHistogramFoldingAggregatorFactory("a5:agg", "hist_m1", null, null, null, null),
new NotDimFilter(new SelectorDimFilter("dim1", "abc", null))
)
))
.postAggregators(ImmutableList.<PostAggregator>of(
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
new QuantilePostAggregator("a2", "a2:agg", 0.98f),
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
new QuantilePostAggregator("a4", "a4:agg", 0.99f),
new QuantilePostAggregator("a5", "a5:agg", 0.999f),
new QuantilePostAggregator("a6", "a4:agg", 0.999f)
))
.context(CalciteQueryTest.TIMESERIES_CONTEXT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);

File: Aggregations.java

@@ -0,0 +1,60 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.sql.calcite.aggregation;
import com.google.common.base.Predicate;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.FilteredAggregatorFactory;
import io.druid.query.filter.DimFilter;
public class Aggregations
{
private Aggregations()
{
// No instantiation.
}
/**
* Returns true if "factory" is an aggregator factory that either matches "predicate" (if filter is null) or is
* a filtered aggregator factory whose filter is equal to "filter" and underlying aggregator matches "predicate".
*
* @param factory factory to match
* @param filter filter, may be null
* @param clazz class of factory to match
* @param predicate predicate
*
* @return true if the aggregator matches filter + predicate
*/
public static <T extends AggregatorFactory> boolean aggregatorMatches(
final AggregatorFactory factory,
final DimFilter filter,
final Class<T> clazz,
final Predicate<T> predicate
)
{
if (filter != null) {
return factory instanceof FilteredAggregatorFactory &&
((FilteredAggregatorFactory) factory).getFilter().equals(filter)
&& aggregatorMatches(((FilteredAggregatorFactory) factory).getAggregator(), null, clazz, predicate);
} else {
return clazz.isAssignableFrom(factory.getClass()) && predicate.apply(clazz.cast(factory));
}
}
}
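
A brief usage sketch (illustrative only, not part of this commit; the concrete factory, filter, and names here
are assumed for the example, and io.druid.query.aggregation.DoubleSumAggregatorFactory plus
io.druid.query.filter.SelectorDimFilter are assumed imported): a filtered factory matches only when its filter
equals the requested one and the unwrapped factory satisfies the predicate.

// Hypothetical existing aggregation, as would be planned for SUM(m1) FILTER(WHERE dim1 = 'abc').
final AggregatorFactory existing = new FilteredAggregatorFactory(
    new DoubleSumAggregatorFactory("a0:agg", "m1"),
    new SelectorDimFilter("dim1", "abc", null)
);
final Predicate<DoubleSumAggregatorFactory> sumsM1 = new Predicate<DoubleSumAggregatorFactory>()
{
  @Override
  public boolean apply(final DoubleSumAggregatorFactory factory)
  {
    return "m1".equals(factory.getFieldName());
  }
};
// Same filter and a matching inner factory: true, so the planner may re-use the aggregator.
Aggregations.aggregatorMatches(existing, new SelectorDimFilter("dim1", "abc", null), DoubleSumAggregatorFactory.class, sumsM1);
// Different filter: false, which is exactly the incorrect re-use this commit fixes.
Aggregations.aggregatorMatches(existing, new SelectorDimFilter("dim1", "xyz", null), DoubleSumAggregatorFactory.class, sumsM1);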

File: ApproxCountDistinctSqlAggregator.java

@@ -27,6 +27,7 @@ import io.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import io.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import io.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import io.druid.query.dimension.DimensionSpec;
import io.druid.query.filter.DimFilter;
import io.druid.segment.column.ValueType;
import io.druid.sql.calcite.expression.Expressions;
import io.druid.sql.calcite.expression.RowExtraction;
@@ -60,7 +61,8 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
final RowSignature rowSignature,
final List<Aggregation> existingAggregations,
final Project project,
final AggregateCall aggregateCall
final AggregateCall aggregateCall,
final DimFilter filter
)
{
final RowExtraction rex = Expressions.toRowExtraction(
@@ -99,7 +101,7 @@
return new HyperUniqueFinalizingPostAggregator(outputName, name);
}
}
);
).filter(filter);
}
private static class ApproxCountDistinctSqlAggFunction extends SqlAggFunction

File: SqlAggregator.java

@@ -19,6 +19,7 @@
package io.druid.sql.calcite.aggregation;
import io.druid.query.filter.DimFilter;
import io.druid.sql.calcite.table.RowSignature;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
@@ -44,9 +45,11 @@
*
* @param name desired output name of the aggregation
* @param rowSignature signature of the rows being aggregated
* @param existingAggregations existing aggregations for this query; useful for re-using aggregators
* @param project SQL projection to apply before the aggregate call
* @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely
* ignored if you do not want to re-use existing aggregations.
* @param project SQL projection to apply before the aggregate call, may be null
* @param aggregateCall SQL aggregate call
* @param filter filter that should be applied to the aggregation, may be null
*
* @return aggregation, or null if the call cannot be translated
*/
@@ -56,6 +59,7 @@
final RowSignature rowSignature,
final List<Aggregation> existingAggregations,
final Project project,
final AggregateCall aggregateCall
final AggregateCall aggregateCall,
final DimFilter filter
);
}
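
For orientation, a minimal hypothetical implementor of the revised interface could look like the sketch below
(FUNCTION_INSTANCE is a placeholder that a real implementation would define like QuantileSqlAggFunction above;
CountAggregatorFactory is assumed imported):

public class CountingSqlAggregator implements SqlAggregator
{
  // Placeholder; define like QuantileSqlAggFunction for a real aggregator.
  private static final SqlAggFunction FUNCTION_INSTANCE = null;

  @Override
  public SqlAggFunction calciteFunction()
  {
    return FUNCTION_INSTANCE;
  }

  @Override
  public Aggregation toDruidAggregation(
      final String name,
      final RowSignature rowSignature,
      final List<Aggregation> existingAggregations,
      final Project project,
      final AggregateCall aggregateCall,
      final DimFilter filter
  )
  {
    // Attach the planner-provided filter (possibly null) so that
    // AGG(...) FILTER(WHERE ...) plans to a Druid filtered aggregator.
    return Aggregation.create(new CountAggregatorFactory(name)).filter(filter);
  }
}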

File: GroupByRules.java

@@ -739,7 +739,6 @@ public class GroupByRules
final String name = aggOutputName(aggNumber);
final SqlKind kind = call.getAggregation().getKind();
final SqlTypeName outputType = call.getType().getSqlTypeName();
final Aggregation retVal;
if (call.filterArg >= 0) {
// AGG(xxx) FILTER(WHERE yyy)
@@ -759,15 +758,16 @@
if (kind == SqlKind.COUNT && call.getArgList().isEmpty()) {
// COUNT(*)
retVal = Aggregation.create(new CountAggregatorFactory(name));
return Aggregation.create(new CountAggregatorFactory(name)).filter(makeFilter(filters, sourceRowSignature));
} else if (kind == SqlKind.COUNT && call.isDistinct()) {
// COUNT(DISTINCT x)
retVal = approximateCountDistinct ? APPROX_COUNT_DISTINCT.toDruidAggregation(
return approximateCountDistinct ? APPROX_COUNT_DISTINCT.toDruidAggregation(
name,
sourceRowSignature,
existingAggregations,
project,
call
call,
makeFilter(filters, sourceRowSignature)
) : null;
} else if (kind == SqlKind.COUNT
|| kind == SqlKind.SUM
@@ -845,9 +845,10 @@
if (forceCount || kind == SqlKind.COUNT) {
// COUNT(x)
retVal = Aggregation.create(new CountAggregatorFactory(name));
return Aggregation.create(new CountAggregatorFactory(name)).filter(makeFilter(filters, sourceRowSignature));
} else {
// Built-in aggregator that is not COUNT.
final Aggregation retVal;
final String fieldName = input.getFieldName();
final String expression = input.getExpression();
@@ -888,26 +889,21 @@
// Not reached.
throw new ISE("WTF?! Kind[%s] got into the built-in aggregator path somehow?!", kind);
}
return retVal.filter(makeFilter(filters, sourceRowSignature));
}
} else {
// Not a built-in aggregator, check operator table.
final SqlAggregator sqlAggregator = operatorTable.lookupAggregator(call.getAggregation().getName());
retVal = sqlAggregator != null ? sqlAggregator.toDruidAggregation(
return sqlAggregator != null ? sqlAggregator.toDruidAggregation(
name,
sourceRowSignature,
existingAggregations,
project,
call
call,
makeFilter(filters, sourceRowSignature)
) : null;
}
final DimFilter filter = filters.isEmpty()
? null
: Filtration.create(new AndDimFilter(filters))
.optimizeFilterOnly(sourceRowSignature)
.getDimFilter();
return retVal != null ? retVal.filter(filter) : null;
}
public static String dimOutputName(final int dimNumber)
@@ -924,4 +920,13 @@
{
return "A" + aggNumber + ":" + key;
}
private static DimFilter makeFilter(final List<DimFilter> filters, final RowSignature sourceRowSignature)
{
return filters.isEmpty()
? null
: Filtration.create(new AndDimFilter(filters))
.optimizeFilterOnly(sourceRowSignature)
.getDimFilter();
}
}
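
Factoring filter construction into makeFilter means every aggregation path (COUNT(*), COUNT(DISTINCT), the
built-in aggregators, and operator-table aggregators such as APPROX_QUANTILE) now receives the combined,
optimized DimFilter up front, which lets SqlAggregator implementations account for the filter when deciding
whether an existing aggregation can be re-used.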

File: DruidSchema.java

@@ -50,6 +50,7 @@ import io.druid.server.coordination.DruidServerMetadata;
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.QueryMaker;
import io.druid.sql.calcite.table.DruidTable;
import io.druid.sql.calcite.table.RowSignature;
import io.druid.timeline.DataSegment;
import org.apache.calcite.schema.Table;
import org.apache.calcite.schema.impl.AbstractSchema;
@@ -296,7 +297,7 @@ public class DruidSchema extends AbstractSchema
}
final Map<String, ColumnAnalysis> columnMetadata = Iterables.getOnlyElement(results).getColumns();
final Map<String, ValueType> columnValueTypes = Maps.newHashMap();
final RowSignature.Builder rowSignature = RowSignature.builder();
for (Map.Entry<String, ColumnAnalysis> entry : columnMetadata.entrySet()) {
if (entry.getValue().isError()) {
@@ -314,13 +315,13 @@
valueType = ValueType.COMPLEX;
}
columnValueTypes.put(entry.getKey(), valueType);
rowSignature.add(entry.getKey(), valueType);
}
return new DruidTable(
queryMaker,
new TableDataSource(dataSource),
columnValueTypes
rowSignature.build()
);
}
}

File: DruidTable.java

@@ -20,9 +20,7 @@
package io.druid.sql.calcite.table;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSortedMap;
import io.druid.query.DataSource;
import io.druid.segment.column.ValueType;
import io.druid.sql.calcite.rel.DruidQueryRel;
import io.druid.sql.calcite.rel.QueryMaker;
import org.apache.calcite.plan.RelOptCluster;
@@ -35,8 +33,6 @@ import org.apache.calcite.schema.Statistic;
import org.apache.calcite.schema.Statistics;
import org.apache.calcite.schema.TranslatableTable;
import java.util.Map;
public class DruidTable implements TranslatableTable
{
private final QueryMaker queryMaker;
@@ -46,17 +42,12 @@ public class DruidTable implements TranslatableTable
public DruidTable(
final QueryMaker queryMaker,
final DataSource dataSource,
final Map<String, ValueType> columns
final RowSignature rowSignature
)
{
this.queryMaker = Preconditions.checkNotNull(queryMaker, "queryMaker");
this.dataSource = Preconditions.checkNotNull(dataSource, "dataSource");
final RowSignature.Builder rowSignatureBuilder = RowSignature.builder();
for (Map.Entry<String, ValueType> entry : ImmutableSortedMap.copyOf(columns).entrySet()) {
rowSignatureBuilder.add(entry.getKey(), entry.getValue());
}
this.rowSignature = rowSignatureBuilder.build();
this.rowSignature = Preconditions.checkNotNull(rowSignature, "rowSignature");
}
public QueryMaker getQueryMaker()

File: RowSignature.java

@@ -51,7 +51,7 @@ public class RowSignature
private final Map<String, ValueType> columnTypes;
private final List<String> columnNames;
public RowSignature(final List<Pair<String, ValueType>> columnTypeList)
private RowSignature(final List<Pair<String, ValueType>> columnTypeList)
{
final Map<String, ValueType> columnTypes0 = Maps.newHashMap();
final ImmutableList.Builder<String> columnNamesBuilder = ImmutableList.builder();
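
With the constructor now private, callers construct signatures through the builder, as in this minimal sketch
(column names borrowed from the test fixtures):

final RowSignature rowSignature = RowSignature.builder()
    .add("__time", ValueType.LONG)
    .add("dim1", ValueType.STRING)
    .add("m1", ValueType.FLOAT)
    .build();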

File: CalciteTests.java

@@ -20,6 +20,7 @@
package io.druid.sql.calcite.util;
import com.google.common.base.Supplier;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
@@ -36,8 +37,6 @@ import io.druid.query.Query;
import io.druid.query.QueryRunnerFactory;
import io.druid.query.QueryRunnerFactoryConglomerate;
import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.QuerySegmentWalker;
import io.druid.query.TableDataSource;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
@@ -65,19 +64,14 @@ import io.druid.query.topn.TopNQueryRunnerFactory;
import io.druid.segment.IndexBuilder;
import io.druid.segment.QueryableIndex;
import io.druid.segment.TestHelper;
import io.druid.segment.column.ValueType;
import io.druid.segment.incremental.IncrementalIndexSchema;
import io.druid.sql.calcite.aggregation.ApproxCountDistinctSqlAggregator;
import io.druid.sql.calcite.aggregation.SqlAggregator;
import io.druid.sql.calcite.planner.DruidOperatorTable;
import io.druid.sql.calcite.planner.PlannerConfig;
import io.druid.sql.calcite.rel.QueryMaker;
import io.druid.sql.calcite.table.DruidTable;
import io.druid.sql.calcite.schema.DruidSchema;
import io.druid.timeline.DataSegment;
import io.druid.timeline.partition.LinearShardSpec;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.Table;
import org.apache.calcite.schema.impl.AbstractSchema;
import org.joda.time.DateTime;
import java.io.File;
@@ -208,7 +202,7 @@ public class CalciteTests
.withRollup(false)
.build();
private static final List<InputRow> ROWS1 = ImmutableList.of(
public static final List<InputRow> ROWS1 = ImmutableList.of(
createRow(ImmutableMap.of("t", "2000-01-01", "m1", "1.0", "dim1", "", "dim2", ImmutableList.of("a"))),
createRow(ImmutableMap.of("t", "2000-01-02", "m1", "2.0", "dim1", "10.1", "dim2", ImmutableList.of())),
createRow(ImmutableMap.of("t", "2000-01-03", "m1", "3.0", "dim1", "2", "dim2", ImmutableList.of(""))),
@@ -217,21 +211,12 @@ public class CalciteTests
createRow(ImmutableMap.of("t", "2001-01-03", "m1", "6.0", "dim1", "abc"))
);
private static final List<InputRow> ROWS2 = ImmutableList.of(
public static final List<InputRow> ROWS2 = ImmutableList.of(
createRow("2000-01-01", "דרואיד", "he", 1.0),
createRow("2000-01-01", "druid", "en", 1.0),
createRow("2000-01-01", "друид", "ru", 1.0)
);
private static final Map<String, ValueType> COLUMN_TYPES = ImmutableMap.<String, ValueType>builder()
.put("__time", ValueType.LONG)
.put("cnt", ValueType.LONG)
.put("dim1", ValueType.STRING)
.put("dim2", ValueType.STRING)
.put("m1", ValueType.FLOAT)
.put("unique_dim1", ValueType.COMPLEX)
.build();
private CalciteTests()
{
// No instantiation.
@@ -282,23 +267,27 @@
return new DruidOperatorTable(ImmutableSet.<SqlAggregator>of(new ApproxCountDistinctSqlAggregator()));
}
public static Schema createMockSchema(final QuerySegmentWalker walker, final PlannerConfig plannerConfig)
public static DruidSchema createMockSchema(
final SpecificSegmentsQuerySegmentWalker walker,
final PlannerConfig plannerConfig
)
{
final QueryMaker queryMaker = new QueryMaker(walker, plannerConfig);
final DruidTable druidTable1 = new DruidTable(queryMaker, new TableDataSource(DATASOURCE1), COLUMN_TYPES);
final DruidTable druidTable2 = new DruidTable(queryMaker, new TableDataSource(DATASOURCE2), COLUMN_TYPES);
final Map<String, Table> tableMap = ImmutableMap.<String, Table>of(
DATASOURCE1, druidTable1,
DATASOURCE2, druidTable2
final DruidSchema schema = new DruidSchema(
walker,
new TestServerInventoryView(walker.getSegments()),
plannerConfig
);
return new AbstractSchema()
{
@Override
protected Map<String, Table> getTableMap()
{
return tableMap;
}
};
schema.start();
try {
schema.awaitInitialization();
}
catch (InterruptedException e) {
throw Throwables.propagate(e);
}
schema.stop();
return schema;
}
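
The mock schema is started and awaited before being returned, then stopped, presumably because tests only need
the table metadata computed during initialization rather than ongoing segment discovery.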
public static InputRow createRow(final ImmutableMap<String, ?> map)