mirror of https://github.com/apache/druid.git
SQL: Add resolution parameter, fix filtering bug with APPROX_QUANTILE (#3868)
* SQL: Add resolution parameter to quantile agg, rename to APPROX_QUANTILE. * Fix bug with re-use of filtered approximate histogram aggregators. Also add APPROX_QUANTILE tests for filtering and running on complex columns. Includes some slight refactoring to allow tests to make DruidTables that include complex columns. * Remove unused import
This commit is contained in:
parent
75d9e5e7a7
commit
ac84a3e011
|
@ -53,6 +53,7 @@ import io.druid.sql.calcite.planner.PlannerFactory;
|
|||
import io.druid.sql.calcite.planner.PlannerResult;
|
||||
import io.druid.sql.calcite.rel.QueryMaker;
|
||||
import io.druid.sql.calcite.table.DruidTable;
|
||||
import io.druid.sql.calcite.table.RowSignature;
|
||||
import io.druid.sql.calcite.util.CalciteTests;
|
||||
import io.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
|
||||
import io.druid.timeline.DataSegment;
|
||||
|
@ -158,12 +159,12 @@ public class SqlBenchmark
|
|||
new DruidTable(
|
||||
new QueryMaker(walker, plannerConfig),
|
||||
new TableDataSource("foo"),
|
||||
ImmutableMap.of(
|
||||
"__time", ValueType.LONG,
|
||||
"dimSequential", ValueType.STRING,
|
||||
"dimZipf", ValueType.STRING,
|
||||
"dimUniform", ValueType.STRING
|
||||
)
|
||||
RowSignature.builder()
|
||||
.add("__time", ValueType.LONG)
|
||||
.add("dimSequential", ValueType.STRING)
|
||||
.add("dimZipf", ValueType.STRING)
|
||||
.add("dimUniform", ValueType.STRING)
|
||||
.build()
|
||||
)
|
||||
);
|
||||
final Schema druidSchema = new AbstractSchema()
|
||||
|
|
|
@ -149,8 +149,10 @@ Some Druid extensions also include SQL language extensions.
|
|||
|
||||
If the [approximate histogram extension](../development/extensions-core/approximate-histograms.html) is loaded:
|
||||
|
||||
- `QUANTILE(column, probability)` on numeric or approximate histogram columns computes approximate quantiles. The
|
||||
"probability" should be between 0 and 1 (exclusive).
|
||||
- `APPROX_QUANTILE(column, probability)` or `APPROX_QUANTILE(column, probability, resolution)` on numeric or
|
||||
approximate histogram columns computes approximate quantiles. The "probability" should be between 0 and 1 (exclusive).
|
||||
The "resolution" is the number of centroids to use for the computation. Higher resolutions will be give more
|
||||
precise results but also have higher overhead. If not provided, the default resolution is 50.
|
||||
|
||||
### Unsupported features
|
||||
|
||||
|
|
|
@ -19,14 +19,17 @@
|
|||
|
||||
package io.druid.query.aggregation.histogram.sql;
|
||||
|
||||
import com.google.common.base.Predicate;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.druid.query.aggregation.AggregatorFactory;
|
||||
import io.druid.query.aggregation.histogram.ApproximateHistogram;
|
||||
import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactory;
|
||||
import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
|
||||
import io.druid.query.aggregation.histogram.QuantilePostAggregator;
|
||||
import io.druid.query.filter.DimFilter;
|
||||
import io.druid.segment.column.ValueType;
|
||||
import io.druid.sql.calcite.aggregation.Aggregation;
|
||||
import io.druid.sql.calcite.aggregation.Aggregations;
|
||||
import io.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import io.druid.sql.calcite.expression.Expressions;
|
||||
import io.druid.sql.calcite.expression.RowExtraction;
|
||||
|
@ -48,7 +51,7 @@ import java.util.List;
|
|||
public class QuantileSqlAggregator implements SqlAggregator
|
||||
{
|
||||
private static final SqlAggFunction FUNCTION_INSTANCE = new QuantileSqlAggFunction();
|
||||
private static final String NAME = "QUANTILE";
|
||||
private static final String NAME = "APPROX_QUANTILE";
|
||||
|
||||
@Override
|
||||
public SqlAggFunction calciteFunction()
|
||||
|
@ -62,7 +65,8 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
final RowSignature rowSignature,
|
||||
final List<Aggregation> existingAggregations,
|
||||
final Project project,
|
||||
final AggregateCall aggregateCall
|
||||
final AggregateCall aggregateCall,
|
||||
final DimFilter filter
|
||||
)
|
||||
{
|
||||
final RowExtraction rex = Expressions.toRowExtraction(
|
||||
|
@ -77,6 +81,8 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
return null;
|
||||
}
|
||||
|
||||
final AggregatorFactory aggregatorFactory;
|
||||
final String histogramName = String.format("%s:agg", name);
|
||||
final RexNode probabilityArg = Expressions.fromFieldAccess(
|
||||
rowSignature,
|
||||
project,
|
||||
|
@ -84,10 +90,18 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
);
|
||||
final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
|
||||
|
||||
final AggregatorFactory aggregatorFactory;
|
||||
final String histogramName = String.format("%s:agg", name);
|
||||
final int resolution;
|
||||
if (aggregateCall.getArgList().size() >= 3) {
|
||||
final RexNode resolutionArg = Expressions.fromFieldAccess(
|
||||
rowSignature,
|
||||
project,
|
||||
aggregateCall.getArgList().get(2)
|
||||
);
|
||||
resolution = ((Number) RexLiteral.value(resolutionArg)).intValue();
|
||||
} else {
|
||||
resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
|
||||
}
|
||||
|
||||
final int resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
|
||||
final int numBuckets = ApproximateHistogram.DEFAULT_BUCKET_SIZE;
|
||||
final float lowerLimit = Float.NEGATIVE_INFINITY;
|
||||
final float upperLimit = Float.POSITIVE_INFINITY;
|
||||
|
@ -95,22 +109,33 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
// Look for existing matching aggregatorFactory.
|
||||
for (final Aggregation existing : existingAggregations) {
|
||||
for (AggregatorFactory factory : existing.getAggregatorFactories()) {
|
||||
if (factory instanceof ApproximateHistogramAggregatorFactory) {
|
||||
final ApproximateHistogramAggregatorFactory theFactory = (ApproximateHistogramAggregatorFactory) factory;
|
||||
if (theFactory.getFieldName().equals(rex.getColumn())
|
||||
final boolean matches = Aggregations.aggregatorMatches(
|
||||
factory,
|
||||
filter,
|
||||
ApproximateHistogramAggregatorFactory.class,
|
||||
new Predicate<ApproximateHistogramAggregatorFactory>()
|
||||
{
|
||||
@Override
|
||||
public boolean apply(final ApproximateHistogramAggregatorFactory theFactory)
|
||||
{
|
||||
return theFactory.getFieldName().equals(rex.getColumn())
|
||||
&& theFactory.getResolution() == resolution
|
||||
&& theFactory.getNumBuckets() == numBuckets
|
||||
&& theFactory.getLowerLimit() == lowerLimit
|
||||
&& theFactory.getUpperLimit() == upperLimit) {
|
||||
&& theFactory.getUpperLimit() == upperLimit;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
if (matches) {
|
||||
// Found existing one. Use this.
|
||||
return Aggregation.create(
|
||||
ImmutableList.<AggregatorFactory>of(),
|
||||
new QuantilePostAggregator(name, theFactory.getName(), probability)
|
||||
new QuantilePostAggregator(name, factory.getName(), probability)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (rowSignature.getColumnType(rex.getColumn()) == ValueType.COMPLEX) {
|
||||
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
|
||||
|
@ -135,12 +160,13 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
return Aggregation.create(
|
||||
ImmutableList.of(aggregatorFactory),
|
||||
new QuantilePostAggregator(name, histogramName, probability)
|
||||
);
|
||||
).filter(filter);
|
||||
}
|
||||
|
||||
private static class QuantileSqlAggFunction extends SqlAggFunction
|
||||
{
|
||||
private static final String SIGNATURE = "'" + NAME + "(column, probability)'";
|
||||
private static final String SIGNATURE1 = "'" + NAME + "(column, probability)'\n";
|
||||
private static final String SIGNATURE2 = "'" + NAME + "(column, probability, resolution)'\n";
|
||||
|
||||
QuantileSqlAggFunction()
|
||||
{
|
||||
|
@ -150,10 +176,16 @@ public class QuantileSqlAggregator implements SqlAggregator
|
|||
SqlKind.OTHER_FUNCTION,
|
||||
ReturnTypes.explicit(SqlTypeName.DOUBLE),
|
||||
null,
|
||||
OperandTypes.or(
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE, OperandTypes.ANY, OperandTypes.LITERAL),
|
||||
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
),
|
||||
OperandTypes.and(
|
||||
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
|
||||
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
|
||||
)
|
||||
),
|
||||
SqlFunctionCategory.NUMERIC,
|
||||
false,
|
||||
false
|
||||
|
|
|
@ -26,10 +26,22 @@ import io.druid.granularity.QueryGranularities;
|
|||
import io.druid.java.util.common.guava.Sequences;
|
||||
import io.druid.query.Druids;
|
||||
import io.druid.query.aggregation.AggregatorFactory;
|
||||
import io.druid.query.aggregation.CountAggregatorFactory;
|
||||
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import io.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import io.druid.query.aggregation.PostAggregator;
|
||||
import io.druid.query.aggregation.histogram.ApproximateHistogramAggregatorFactory;
|
||||
import io.druid.query.aggregation.histogram.ApproximateHistogramDruidModule;
|
||||
import io.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory;
|
||||
import io.druid.query.aggregation.histogram.QuantilePostAggregator;
|
||||
import io.druid.query.filter.NotDimFilter;
|
||||
import io.druid.query.filter.SelectorDimFilter;
|
||||
import io.druid.query.spec.MultipleIntervalSegmentSpec;
|
||||
import io.druid.segment.IndexBuilder;
|
||||
import io.druid.segment.QueryableIndex;
|
||||
import io.druid.segment.TestHelper;
|
||||
import io.druid.segment.incremental.IncrementalIndexSchema;
|
||||
import io.druid.sql.calcite.CalciteQueryTest;
|
||||
import io.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import io.druid.sql.calcite.filtration.Filtration;
|
||||
import io.druid.sql.calcite.planner.Calcites;
|
||||
|
@ -40,6 +52,8 @@ import io.druid.sql.calcite.planner.PlannerResult;
|
|||
import io.druid.sql.calcite.util.CalciteTests;
|
||||
import io.druid.sql.calcite.util.QueryLogHook;
|
||||
import io.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
|
||||
import io.druid.timeline.DataSegment;
|
||||
import io.druid.timeline.partition.LinearShardSpec;
|
||||
import org.apache.calcite.schema.SchemaPlus;
|
||||
import org.apache.calcite.tools.Planner;
|
||||
import org.junit.After;
|
||||
|
@ -52,10 +66,10 @@ import org.junit.rules.TemporaryFolder;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static io.druid.sql.calcite.CalciteQueryTest.TIMESERIES_CONTEXT;
|
||||
|
||||
public class QuantileSqlAggregatorTest
|
||||
{
|
||||
private static final String DATA_SOURCE = "foo";
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder temporaryFolder = new TemporaryFolder();
|
||||
|
||||
|
@ -69,7 +83,45 @@ public class QuantileSqlAggregatorTest
|
|||
public void setUp() throws Exception
|
||||
{
|
||||
Calcites.setSystemProperties();
|
||||
walker = CalciteTests.createMockWalker(temporaryFolder.newFolder());
|
||||
|
||||
// Note: this is needed in order to properly register the serde for Histogram.
|
||||
new ApproximateHistogramDruidModule().configure(null);
|
||||
|
||||
final QueryableIndex index = IndexBuilder.create()
|
||||
.tmpDir(temporaryFolder.newFolder())
|
||||
.indexMerger(TestHelper.getTestIndexMergerV9())
|
||||
.schema(
|
||||
new IncrementalIndexSchema.Builder()
|
||||
.withMetrics(
|
||||
new AggregatorFactory[]{
|
||||
new CountAggregatorFactory("cnt"),
|
||||
new DoubleSumAggregatorFactory("m1", "m1"),
|
||||
new ApproximateHistogramAggregatorFactory(
|
||||
"hist_m1",
|
||||
"m1",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
)
|
||||
}
|
||||
)
|
||||
.withRollup(false)
|
||||
.build()
|
||||
)
|
||||
.rows(CalciteTests.ROWS1)
|
||||
.buildMMappedIndex();
|
||||
|
||||
walker = new SpecificSegmentsQuerySegmentWalker(CalciteTests.queryRunnerFactoryConglomerate()).add(
|
||||
DataSegment.builder()
|
||||
.dataSource(DATA_SOURCE)
|
||||
.interval(index.getDataInterval())
|
||||
.version("1")
|
||||
.shardSpec(new LinearShardSpec(0))
|
||||
.build(),
|
||||
index
|
||||
);
|
||||
|
||||
final PlannerConfig plannerConfig = new PlannerConfig();
|
||||
final SchemaPlus rootSchema = Calcites.createRootSchema(
|
||||
CalciteTests.createMockSchema(
|
||||
|
@ -96,13 +148,23 @@ public class QuantileSqlAggregatorTest
|
|||
public void testQuantileOnFloatAndLongs() throws Exception
|
||||
{
|
||||
try (final Planner planner = plannerFactory.createPlanner()) {
|
||||
final String sql = "SELECT QUANTILE(m1, 0.01), QUANTILE(m1, 0.5), QUANTILE(m1, 0.99), QUANTILE(cnt, 0.5) FROM foo";
|
||||
final String sql = "SELECT\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.01),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.5, 50),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.98, 200),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.99),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n"
|
||||
+ "APPROX_QUANTILE(m1, 0.999) FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE(cnt, 0.5)\n"
|
||||
+ "FROM foo";
|
||||
|
||||
final PlannerResult plannerResult = Calcites.plan(planner, sql);
|
||||
|
||||
// Verify results
|
||||
final List<Object[]> results = Sequences.toList(plannerResult.run(), new ArrayList<Object[]>());
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{1.0, 3.0, 5.940000057220459, 1.0}
|
||||
new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0, 1.0}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
|
@ -115,17 +177,90 @@ public class QuantileSqlAggregatorTest
|
|||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(QueryGranularities.ALL)
|
||||
.aggregators(ImmutableList.<AggregatorFactory>of(
|
||||
.aggregators(ImmutableList.of(
|
||||
new ApproximateHistogramAggregatorFactory("a0:agg", "m1", null, null, null, null),
|
||||
new ApproximateHistogramAggregatorFactory("a3:agg", "cnt", null, null, null, null)
|
||||
new ApproximateHistogramAggregatorFactory("a2:agg", "m1", 200, null, null, null),
|
||||
new FilteredAggregatorFactory(
|
||||
new ApproximateHistogramAggregatorFactory("a4:agg", "m1", null, null, null, null),
|
||||
new SelectorDimFilter("dim1", "abc", null)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new ApproximateHistogramAggregatorFactory("a5:agg", "m1", null, null, null, null),
|
||||
new NotDimFilter(new SelectorDimFilter("dim1", "abc", null))
|
||||
),
|
||||
new ApproximateHistogramAggregatorFactory("a7:agg", "cnt", null, null, null, null)
|
||||
))
|
||||
.postAggregators(ImmutableList.<PostAggregator>of(
|
||||
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
|
||||
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
|
||||
new QuantilePostAggregator("a2", "a0:agg", 0.99f),
|
||||
new QuantilePostAggregator("a3", "a3:agg", 0.50f)
|
||||
new QuantilePostAggregator("a2", "a2:agg", 0.98f),
|
||||
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
|
||||
new QuantilePostAggregator("a4", "a4:agg", 0.99f),
|
||||
new QuantilePostAggregator("a5", "a5:agg", 0.999f),
|
||||
new QuantilePostAggregator("a6", "a4:agg", 0.999f),
|
||||
new QuantilePostAggregator("a7", "a7:agg", 0.50f)
|
||||
))
|
||||
.context(TIMESERIES_CONTEXT)
|
||||
.context(CalciteQueryTest.TIMESERIES_CONTEXT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testQuantileOnComplexColumn() throws Exception
|
||||
{
|
||||
try (final Planner planner = plannerFactory.createPlanner()) {
|
||||
final String sql = "SELECT\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.01),\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.5, 50),\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.98, 200),\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.99),\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.99) FILTER(WHERE dim1 = 'abc'),\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.999) FILTER(WHERE dim1 <> 'abc'),\n"
|
||||
+ "APPROX_QUANTILE(hist_m1, 0.999) FILTER(WHERE dim1 = 'abc')\n"
|
||||
+ "FROM foo";
|
||||
|
||||
final PlannerResult plannerResult = Calcites.plan(planner, sql);
|
||||
|
||||
// Verify results
|
||||
final List<Object[]> results = Sequences.toList(plannerResult.run(), new ArrayList<Object[]>());
|
||||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
|
||||
// Verify query
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(CalciteTests.DATASOURCE1)
|
||||
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.granularity(QueryGranularities.ALL)
|
||||
.aggregators(ImmutableList.of(
|
||||
new ApproximateHistogramFoldingAggregatorFactory("a0:agg", "hist_m1", null, null, null, null),
|
||||
new ApproximateHistogramFoldingAggregatorFactory("a2:agg", "hist_m1", 200, null, null, null),
|
||||
new FilteredAggregatorFactory(
|
||||
new ApproximateHistogramFoldingAggregatorFactory("a4:agg", "hist_m1", null, null, null, null),
|
||||
new SelectorDimFilter("dim1", "abc", null)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
new ApproximateHistogramFoldingAggregatorFactory("a5:agg", "hist_m1", null, null, null, null),
|
||||
new NotDimFilter(new SelectorDimFilter("dim1", "abc", null))
|
||||
)
|
||||
))
|
||||
.postAggregators(ImmutableList.<PostAggregator>of(
|
||||
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
|
||||
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
|
||||
new QuantilePostAggregator("a2", "a2:agg", 0.98f),
|
||||
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
|
||||
new QuantilePostAggregator("a4", "a4:agg", 0.99f),
|
||||
new QuantilePostAggregator("a5", "a5:agg", 0.999f),
|
||||
new QuantilePostAggregator("a6", "a4:agg", 0.999f)
|
||||
))
|
||||
.context(CalciteQueryTest.TIMESERIES_CONTEXT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. Metamarkets licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package io.druid.sql.calcite.aggregation;
|
||||
|
||||
import com.google.common.base.Predicate;
|
||||
import io.druid.query.aggregation.AggregatorFactory;
|
||||
import io.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import io.druid.query.filter.DimFilter;
|
||||
|
||||
public class Aggregations
|
||||
{
|
||||
private Aggregations()
|
||||
{
|
||||
// No instantiation.
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if "factory" is an aggregator factory that either matches "predicate" (if filter is null) or is
|
||||
* a filtered aggregator factory whose filter is equal to "filter" and underlying aggregator matches "predicate".
|
||||
*
|
||||
* @param factory factory to match
|
||||
* @param filter filter, may be null
|
||||
* @param clazz class of factory to match
|
||||
* @param predicate predicate
|
||||
*
|
||||
* @return true if the aggregator matches filter + predicate
|
||||
*/
|
||||
public static <T extends AggregatorFactory> boolean aggregatorMatches(
|
||||
final AggregatorFactory factory,
|
||||
final DimFilter filter,
|
||||
final Class<T> clazz,
|
||||
final Predicate<T> predicate
|
||||
)
|
||||
{
|
||||
if (filter != null) {
|
||||
return factory instanceof FilteredAggregatorFactory &&
|
||||
((FilteredAggregatorFactory) factory).getFilter().equals(filter)
|
||||
&& aggregatorMatches(((FilteredAggregatorFactory) factory).getAggregator(), null, clazz, predicate);
|
||||
} else {
|
||||
return clazz.isAssignableFrom(factory.getClass()) && predicate.apply(clazz.cast(factory));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -27,6 +27,7 @@ import io.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
|
|||
import io.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
|
||||
import io.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
|
||||
import io.druid.query.dimension.DimensionSpec;
|
||||
import io.druid.query.filter.DimFilter;
|
||||
import io.druid.segment.column.ValueType;
|
||||
import io.druid.sql.calcite.expression.Expressions;
|
||||
import io.druid.sql.calcite.expression.RowExtraction;
|
||||
|
@ -60,7 +61,8 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
|
|||
final RowSignature rowSignature,
|
||||
final List<Aggregation> existingAggregations,
|
||||
final Project project,
|
||||
final AggregateCall aggregateCall
|
||||
final AggregateCall aggregateCall,
|
||||
final DimFilter filter
|
||||
)
|
||||
{
|
||||
final RowExtraction rex = Expressions.toRowExtraction(
|
||||
|
@ -99,7 +101,7 @@ public class ApproxCountDistinctSqlAggregator implements SqlAggregator
|
|||
return new HyperUniqueFinalizingPostAggregator(outputName, name);
|
||||
}
|
||||
}
|
||||
);
|
||||
).filter(filter);
|
||||
}
|
||||
|
||||
private static class ApproxCountDistinctSqlAggFunction extends SqlAggFunction
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
package io.druid.sql.calcite.aggregation;
|
||||
|
||||
import io.druid.query.filter.DimFilter;
|
||||
import io.druid.sql.calcite.table.RowSignature;
|
||||
import org.apache.calcite.rel.core.AggregateCall;
|
||||
import org.apache.calcite.rel.core.Project;
|
||||
|
@ -44,9 +45,11 @@ public interface SqlAggregator
|
|||
*
|
||||
* @param name desired output name of the aggregation
|
||||
* @param rowSignature signature of the rows being aggregated
|
||||
* @param existingAggregations existing aggregations for this query; useful for re-using aggregators
|
||||
* @param project SQL projection to apply before the aggregate call
|
||||
* @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely
|
||||
* ignored if you do not want to re-use existing aggregations.
|
||||
* @param project SQL projection to apply before the aggregate call, may be null
|
||||
* @param aggregateCall SQL aggregate call
|
||||
* @param filter filter that should be applied to the aggregation, may be null
|
||||
*
|
||||
* @return aggregation, or null if the call cannot be translated
|
||||
*/
|
||||
|
@ -56,6 +59,7 @@ public interface SqlAggregator
|
|||
final RowSignature rowSignature,
|
||||
final List<Aggregation> existingAggregations,
|
||||
final Project project,
|
||||
final AggregateCall aggregateCall
|
||||
final AggregateCall aggregateCall,
|
||||
final DimFilter filter
|
||||
);
|
||||
}
|
||||
|
|
|
@ -739,7 +739,6 @@ public class GroupByRules
|
|||
final String name = aggOutputName(aggNumber);
|
||||
final SqlKind kind = call.getAggregation().getKind();
|
||||
final SqlTypeName outputType = call.getType().getSqlTypeName();
|
||||
final Aggregation retVal;
|
||||
|
||||
if (call.filterArg >= 0) {
|
||||
// AGG(xxx) FILTER(WHERE yyy)
|
||||
|
@ -759,15 +758,16 @@ public class GroupByRules
|
|||
|
||||
if (kind == SqlKind.COUNT && call.getArgList().isEmpty()) {
|
||||
// COUNT(*)
|
||||
retVal = Aggregation.create(new CountAggregatorFactory(name));
|
||||
return Aggregation.create(new CountAggregatorFactory(name)).filter(makeFilter(filters, sourceRowSignature));
|
||||
} else if (kind == SqlKind.COUNT && call.isDistinct()) {
|
||||
// COUNT(DISTINCT x)
|
||||
retVal = approximateCountDistinct ? APPROX_COUNT_DISTINCT.toDruidAggregation(
|
||||
return approximateCountDistinct ? APPROX_COUNT_DISTINCT.toDruidAggregation(
|
||||
name,
|
||||
sourceRowSignature,
|
||||
existingAggregations,
|
||||
project,
|
||||
call
|
||||
call,
|
||||
makeFilter(filters, sourceRowSignature)
|
||||
) : null;
|
||||
} else if (kind == SqlKind.COUNT
|
||||
|| kind == SqlKind.SUM
|
||||
|
@ -845,9 +845,10 @@ public class GroupByRules
|
|||
|
||||
if (forceCount || kind == SqlKind.COUNT) {
|
||||
// COUNT(x)
|
||||
retVal = Aggregation.create(new CountAggregatorFactory(name));
|
||||
return Aggregation.create(new CountAggregatorFactory(name)).filter(makeFilter(filters, sourceRowSignature));
|
||||
} else {
|
||||
// Built-in aggregator that is not COUNT.
|
||||
final Aggregation retVal;
|
||||
final String fieldName = input.getFieldName();
|
||||
final String expression = input.getExpression();
|
||||
|
||||
|
@ -888,26 +889,21 @@ public class GroupByRules
|
|||
// Not reached.
|
||||
throw new ISE("WTF?! Kind[%s] got into the built-in aggregator path somehow?!", kind);
|
||||
}
|
||||
|
||||
return retVal.filter(makeFilter(filters, sourceRowSignature));
|
||||
}
|
||||
} else {
|
||||
// Not a built-in aggregator, check operator table.
|
||||
final SqlAggregator sqlAggregator = operatorTable.lookupAggregator(call.getAggregation().getName());
|
||||
retVal = sqlAggregator != null ? sqlAggregator.toDruidAggregation(
|
||||
return sqlAggregator != null ? sqlAggregator.toDruidAggregation(
|
||||
name,
|
||||
sourceRowSignature,
|
||||
existingAggregations,
|
||||
project,
|
||||
call
|
||||
call,
|
||||
makeFilter(filters, sourceRowSignature)
|
||||
) : null;
|
||||
}
|
||||
|
||||
final DimFilter filter = filters.isEmpty()
|
||||
? null
|
||||
: Filtration.create(new AndDimFilter(filters))
|
||||
.optimizeFilterOnly(sourceRowSignature)
|
||||
.getDimFilter();
|
||||
|
||||
return retVal != null ? retVal.filter(filter) : null;
|
||||
}
|
||||
|
||||
public static String dimOutputName(final int dimNumber)
|
||||
|
@ -924,4 +920,13 @@ public class GroupByRules
|
|||
{
|
||||
return "A" + aggNumber + ":" + key;
|
||||
}
|
||||
|
||||
private static DimFilter makeFilter(final List<DimFilter> filters, final RowSignature sourceRowSignature)
|
||||
{
|
||||
return filters.isEmpty()
|
||||
? null
|
||||
: Filtration.create(new AndDimFilter(filters))
|
||||
.optimizeFilterOnly(sourceRowSignature)
|
||||
.getDimFilter();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ import io.druid.server.coordination.DruidServerMetadata;
|
|||
import io.druid.sql.calcite.planner.PlannerConfig;
|
||||
import io.druid.sql.calcite.rel.QueryMaker;
|
||||
import io.druid.sql.calcite.table.DruidTable;
|
||||
import io.druid.sql.calcite.table.RowSignature;
|
||||
import io.druid.timeline.DataSegment;
|
||||
import org.apache.calcite.schema.Table;
|
||||
import org.apache.calcite.schema.impl.AbstractSchema;
|
||||
|
@ -296,7 +297,7 @@ public class DruidSchema extends AbstractSchema
|
|||
}
|
||||
|
||||
final Map<String, ColumnAnalysis> columnMetadata = Iterables.getOnlyElement(results).getColumns();
|
||||
final Map<String, ValueType> columnValueTypes = Maps.newHashMap();
|
||||
final RowSignature.Builder rowSignature = RowSignature.builder();
|
||||
|
||||
for (Map.Entry<String, ColumnAnalysis> entry : columnMetadata.entrySet()) {
|
||||
if (entry.getValue().isError()) {
|
||||
|
@ -314,13 +315,13 @@ public class DruidSchema extends AbstractSchema
|
|||
valueType = ValueType.COMPLEX;
|
||||
}
|
||||
|
||||
columnValueTypes.put(entry.getKey(), valueType);
|
||||
rowSignature.add(entry.getKey(), valueType);
|
||||
}
|
||||
|
||||
return new DruidTable(
|
||||
queryMaker,
|
||||
new TableDataSource(dataSource),
|
||||
columnValueTypes
|
||||
rowSignature.build()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,9 +20,7 @@
|
|||
package io.druid.sql.calcite.table;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableSortedMap;
|
||||
import io.druid.query.DataSource;
|
||||
import io.druid.segment.column.ValueType;
|
||||
import io.druid.sql.calcite.rel.DruidQueryRel;
|
||||
import io.druid.sql.calcite.rel.QueryMaker;
|
||||
import org.apache.calcite.plan.RelOptCluster;
|
||||
|
@ -35,8 +33,6 @@ import org.apache.calcite.schema.Statistic;
|
|||
import org.apache.calcite.schema.Statistics;
|
||||
import org.apache.calcite.schema.TranslatableTable;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class DruidTable implements TranslatableTable
|
||||
{
|
||||
private final QueryMaker queryMaker;
|
||||
|
@ -46,17 +42,12 @@ public class DruidTable implements TranslatableTable
|
|||
public DruidTable(
|
||||
final QueryMaker queryMaker,
|
||||
final DataSource dataSource,
|
||||
final Map<String, ValueType> columns
|
||||
final RowSignature rowSignature
|
||||
)
|
||||
{
|
||||
this.queryMaker = Preconditions.checkNotNull(queryMaker, "queryMaker");
|
||||
this.dataSource = Preconditions.checkNotNull(dataSource, "dataSource");
|
||||
|
||||
final RowSignature.Builder rowSignatureBuilder = RowSignature.builder();
|
||||
for (Map.Entry<String, ValueType> entry : ImmutableSortedMap.copyOf(columns).entrySet()) {
|
||||
rowSignatureBuilder.add(entry.getKey(), entry.getValue());
|
||||
}
|
||||
this.rowSignature = rowSignatureBuilder.build();
|
||||
this.rowSignature = Preconditions.checkNotNull(rowSignature, "rowSignature");
|
||||
}
|
||||
|
||||
public QueryMaker getQueryMaker()
|
||||
|
|
|
@ -51,7 +51,7 @@ public class RowSignature
|
|||
private final Map<String, ValueType> columnTypes;
|
||||
private final List<String> columnNames;
|
||||
|
||||
public RowSignature(final List<Pair<String, ValueType>> columnTypeList)
|
||||
private RowSignature(final List<Pair<String, ValueType>> columnTypeList)
|
||||
{
|
||||
final Map<String, ValueType> columnTypes0 = Maps.newHashMap();
|
||||
final ImmutableList.Builder<String> columnNamesBuilder = ImmutableList.builder();
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package io.druid.sql.calcite.util;
|
||||
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Throwables;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
@ -36,8 +37,6 @@ import io.druid.query.Query;
|
|||
import io.druid.query.QueryRunnerFactory;
|
||||
import io.druid.query.QueryRunnerFactoryConglomerate;
|
||||
import io.druid.query.QueryRunnerTestHelper;
|
||||
import io.druid.query.QuerySegmentWalker;
|
||||
import io.druid.query.TableDataSource;
|
||||
import io.druid.query.aggregation.AggregatorFactory;
|
||||
import io.druid.query.aggregation.CountAggregatorFactory;
|
||||
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
|
@ -65,19 +64,14 @@ import io.druid.query.topn.TopNQueryRunnerFactory;
|
|||
import io.druid.segment.IndexBuilder;
|
||||
import io.druid.segment.QueryableIndex;
|
||||
import io.druid.segment.TestHelper;
|
||||
import io.druid.segment.column.ValueType;
|
||||
import io.druid.segment.incremental.IncrementalIndexSchema;
|
||||
import io.druid.sql.calcite.aggregation.ApproxCountDistinctSqlAggregator;
|
||||
import io.druid.sql.calcite.aggregation.SqlAggregator;
|
||||
import io.druid.sql.calcite.planner.DruidOperatorTable;
|
||||
import io.druid.sql.calcite.planner.PlannerConfig;
|
||||
import io.druid.sql.calcite.rel.QueryMaker;
|
||||
import io.druid.sql.calcite.table.DruidTable;
|
||||
import io.druid.sql.calcite.schema.DruidSchema;
|
||||
import io.druid.timeline.DataSegment;
|
||||
import io.druid.timeline.partition.LinearShardSpec;
|
||||
import org.apache.calcite.schema.Schema;
|
||||
import org.apache.calcite.schema.Table;
|
||||
import org.apache.calcite.schema.impl.AbstractSchema;
|
||||
import org.joda.time.DateTime;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -208,7 +202,7 @@ public class CalciteTests
|
|||
.withRollup(false)
|
||||
.build();
|
||||
|
||||
private static final List<InputRow> ROWS1 = ImmutableList.of(
|
||||
public static final List<InputRow> ROWS1 = ImmutableList.of(
|
||||
createRow(ImmutableMap.of("t", "2000-01-01", "m1", "1.0", "dim1", "", "dim2", ImmutableList.of("a"))),
|
||||
createRow(ImmutableMap.of("t", "2000-01-02", "m1", "2.0", "dim1", "10.1", "dim2", ImmutableList.of())),
|
||||
createRow(ImmutableMap.of("t", "2000-01-03", "m1", "3.0", "dim1", "2", "dim2", ImmutableList.of(""))),
|
||||
|
@ -217,21 +211,12 @@ public class CalciteTests
|
|||
createRow(ImmutableMap.of("t", "2001-01-03", "m1", "6.0", "dim1", "abc"))
|
||||
);
|
||||
|
||||
private static final List<InputRow> ROWS2 = ImmutableList.of(
|
||||
public static final List<InputRow> ROWS2 = ImmutableList.of(
|
||||
createRow("2000-01-01", "דרואיד", "he", 1.0),
|
||||
createRow("2000-01-01", "druid", "en", 1.0),
|
||||
createRow("2000-01-01", "друид", "ru", 1.0)
|
||||
);
|
||||
|
||||
private static final Map<String, ValueType> COLUMN_TYPES = ImmutableMap.<String, ValueType>builder()
|
||||
.put("__time", ValueType.LONG)
|
||||
.put("cnt", ValueType.LONG)
|
||||
.put("dim1", ValueType.STRING)
|
||||
.put("dim2", ValueType.STRING)
|
||||
.put("m1", ValueType.FLOAT)
|
||||
.put("unique_dim1", ValueType.COMPLEX)
|
||||
.build();
|
||||
|
||||
private CalciteTests()
|
||||
{
|
||||
// No instantiation.
|
||||
|
@ -282,23 +267,27 @@ public class CalciteTests
|
|||
return new DruidOperatorTable(ImmutableSet.<SqlAggregator>of(new ApproxCountDistinctSqlAggregator()));
|
||||
}
|
||||
|
||||
public static Schema createMockSchema(final QuerySegmentWalker walker, final PlannerConfig plannerConfig)
|
||||
public static DruidSchema createMockSchema(
|
||||
final SpecificSegmentsQuerySegmentWalker walker,
|
||||
final PlannerConfig plannerConfig
|
||||
)
|
||||
{
|
||||
final QueryMaker queryMaker = new QueryMaker(walker, plannerConfig);
|
||||
final DruidTable druidTable1 = new DruidTable(queryMaker, new TableDataSource(DATASOURCE1), COLUMN_TYPES);
|
||||
final DruidTable druidTable2 = new DruidTable(queryMaker, new TableDataSource(DATASOURCE2), COLUMN_TYPES);
|
||||
final Map<String, Table> tableMap = ImmutableMap.<String, Table>of(
|
||||
DATASOURCE1, druidTable1,
|
||||
DATASOURCE2, druidTable2
|
||||
final DruidSchema schema = new DruidSchema(
|
||||
walker,
|
||||
new TestServerInventoryView(walker.getSegments()),
|
||||
plannerConfig
|
||||
);
|
||||
return new AbstractSchema()
|
||||
{
|
||||
@Override
|
||||
protected Map<String, Table> getTableMap()
|
||||
{
|
||||
return tableMap;
|
||||
|
||||
schema.start();
|
||||
try {
|
||||
schema.awaitInitialization();
|
||||
}
|
||||
};
|
||||
catch (InterruptedException e) {
|
||||
throw Throwables.propagate(e);
|
||||
}
|
||||
|
||||
schema.stop();
|
||||
return schema;
|
||||
}
|
||||
|
||||
public static InputRow createRow(final ImmutableMap<String, ?> map)
|
||||
|
|
Loading…
Reference in New Issue