diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchModule.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchModule.java index 2a07c70ec63..68da26a94b4 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchModule.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchModule.java @@ -25,7 +25,9 @@ import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.common.annotations.VisibleForTesting; import com.google.inject.Binder; import org.apache.datasketches.hll.HllSketch; +import org.apache.druid.guice.ExpressionModule; import org.apache.druid.initialization.DruidModule; +import org.apache.druid.query.aggregation.datasketches.hll.sql.HllPostAggExprMacros; import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchApproxCountDistinctSqlAggregator; import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchEstimateOperatorConversion; import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchEstimateWithErrorBoundsOperatorConversion; @@ -64,6 +66,7 @@ public class HllSketchModule implements DruidModule SqlBindings.addOperatorConversion(binder, HllSketchSetUnionOperatorConversion.class); SqlBindings.addOperatorConversion(binder, HllSketchToStringOperatorConversion.class); + ExpressionModule.addExprMacro(binder, HllPostAggExprMacros.HLLSketchEstimateExprMacro.class); SqlBindings.addApproxCountDistinctChoice( binder, HllSketchApproxCountDistinctSqlAggregator.NAME, diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllPostAggExprMacros.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllPostAggExprMacros.java new file mode 100644 index 00000000000..8e310d19d13 --- /dev/null +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllPostAggExprMacros.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.hll.sql; + +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExpressionType; +import org.apache.druid.query.aggregation.datasketches.hll.HllSketchHolder; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.stream.Collectors; + +public class HllPostAggExprMacros +{ + public static final String HLL_SKETCH_ESTIMATE = "hll_sketch_estimate"; + + public static class HLLSketchEstimateExprMacro implements ExprMacroTable.ExprMacro + { + + @Override + public Expr apply(List args) + { + validationHelperCheckAnyOfArgumentCount(args, 1, 2); + return new HllSketchEstimateExpr(args); + } + + @Override + public String name() + { + return HLL_SKETCH_ESTIMATE; + } + } + + public static class HllSketchEstimateExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr + { + private Expr estimateExpr; + private Expr isRound; + + public HllSketchEstimateExpr(List args) + { + super(HLL_SKETCH_ESTIMATE, args); + this.estimateExpr = args.get(0); + if (args.size() == 2) { + isRound = args.get(1); + } + } + + @Nullable + @Override + public ExpressionType getOutputType(InputBindingInspector inspector) + { + return ExpressionType.DOUBLE; + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + boolean round = false; + ExprEval eval = estimateExpr.eval(bindings); + if (isRound != null) { + round = isRound.eval(bindings).asBoolean(); + } + + final Object valObj = eval.value(); + if (valObj == null) { + return ExprEval.of(null); + } + HllSketchHolder h = HllSketchHolder.fromObj(valObj); + double estimate = h.getEstimate(); + return round ? ExprEval.of(Math.round(estimate)) : ExprEval.of(estimate); + } + + @Override + public Expr visit(Shuttle shuttle) + { + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + return shuttle.visit(new HllSketchEstimateExpr(newArgs)); + } + } +} + diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchEstimateOperatorConversion.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchEstimateOperatorConversion.java index 63980972150..fa154933531 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchEstimateOperatorConversion.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchEstimateOperatorConversion.java @@ -43,7 +43,7 @@ import java.util.List; public class HllSketchEstimateOperatorConversion implements SqlOperatorConversion { - private static final String FUNCTION_NAME = "HLL_SKETCH_ESTIMATE"; + private static final String FUNCTION_NAME = "hll_sketch_estimate"; private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME)) .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.BOOLEAN) @@ -64,7 +64,12 @@ public class HllSketchEstimateOperatorConversion implements SqlOperatorConversio RexNode rexNode ) { - return null; + return OperatorConversions.convertDirectCall( + plannerContext, + rowSignature, + rexNode, + FUNCTION_NAME + ); } @Nullable diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchModule.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchModule.java index 9d5746533ff..979f3f2579f 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchModule.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchModule.java @@ -24,7 +24,9 @@ import com.fasterxml.jackson.databind.jsontype.NamedType; import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.common.annotations.VisibleForTesting; import com.google.inject.Binder; +import org.apache.druid.guice.ExpressionModule; import org.apache.druid.initialization.DruidModule; +import org.apache.druid.query.aggregation.datasketches.theta.sql.ThetaPostAggMacros; import org.apache.druid.query.aggregation.datasketches.theta.sql.ThetaSketchApproxCountDistinctSqlAggregator; import org.apache.druid.query.aggregation.datasketches.theta.sql.ThetaSketchEstimateOperatorConversion; import org.apache.druid.query.aggregation.datasketches.theta.sql.ThetaSketchEstimateWithErrorBoundsOperatorConversion; @@ -71,6 +73,7 @@ public class SketchModule implements DruidModule ThetaSketchApproxCountDistinctSqlAggregator.NAME, ThetaSketchApproxCountDistinctSqlAggregator.class ); + ExpressionModule.addExprMacro(binder, ThetaPostAggMacros.ThetaSketchEstimateExprMacro.class); } @Override diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaPostAggMacros.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaPostAggMacros.java new file mode 100644 index 00000000000..762da9ae476 --- /dev/null +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaPostAggMacros.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.datasketches.theta.sql; + +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExpressionType; +import org.apache.druid.query.aggregation.datasketches.theta.SketchHolder; + +import javax.annotation.Nullable; +import java.util.List; + +public class ThetaPostAggMacros +{ + public static final String THETA_SKETCH_ESTIMATE = "theta_sketch_estimate"; + + public static class ThetaSketchEstimateExprMacro implements ExprMacroTable.ExprMacro + { + + @Override + public Expr apply(List args) + { + validationHelperCheckArgumentCount(args, 1); + return new ThetaSketchEstimateExpr(args.get(0)); + } + + @Override + public String name() + { + return THETA_SKETCH_ESTIMATE; + } + } + + public static class ThetaSketchEstimateExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr + { + private Expr estimateExpr; + + public ThetaSketchEstimateExpr(Expr arg) + { + super(THETA_SKETCH_ESTIMATE, arg); + this.estimateExpr = arg; + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + ExprEval eval = estimateExpr.eval(bindings); + final Object valObj = eval.value(); + if (valObj == null) { + return ExprEval.of(null); + } + if (valObj instanceof SketchHolder) { + SketchHolder thetaSketchHolder = (SketchHolder) valObj; + double estimate = thetaSketchHolder.getEstimate(); + return ExprEval.of(estimate); + } else { + throw new IllegalArgumentException("requires a ThetaSketch as the argument"); + } + } + + @Override + public Expr visit(Shuttle shuttle) + { + return shuttle.visit(new ThetaSketchEstimateExpr(arg)); + } + + @Nullable + @Override + public ExpressionType getOutputType(InputBindingInspector inspector) + { + return ExpressionType.DOUBLE; + } + } +} diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchEstimateOperatorConversion.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchEstimateOperatorConversion.java index 48d19e5c5f1..cb0d93efd02 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchEstimateOperatorConversion.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchEstimateOperatorConversion.java @@ -40,7 +40,7 @@ import java.util.List; public class ThetaSketchEstimateOperatorConversion implements SqlOperatorConversion { - private static final String FUNCTION_NAME = "THETA_SKETCH_ESTIMATE"; + private static final String FUNCTION_NAME = "theta_sketch_estimate"; private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(StringUtils.toUpperCase(FUNCTION_NAME)) .operandTypes(SqlTypeFamily.ANY) @@ -60,7 +60,7 @@ public class ThetaSketchEstimateOperatorConversion implements SqlOperatorConvers RexNode rexNode ) { - return null; + return OperatorConversions.convertDirectCall(plannerContext, rowSignature, rexNode, FUNCTION_NAME); } @Nullable diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index aa0e9afd13a..9d7c4b8c03a 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -27,6 +27,7 @@ import org.apache.druid.guice.DruidInjectorBuilder; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.PeriodGranularity; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.BaseQuery; import org.apache.druid.query.Druids; import org.apache.druid.query.QueryDataSource; @@ -49,9 +50,15 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregat import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; +import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.timeseries.TimeseriesQuery; +import org.apache.druid.query.topn.InvertedTopNMetricSpec; +import org.apache.druid.query.topn.NumericTopNMetricSpec; +import org.apache.druid.query.topn.TopNQueryBuilder; import org.apache.druid.segment.IndexBuilder; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.column.ColumnType; @@ -69,6 +76,7 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; import org.joda.time.DateTimeZone; import org.joda.time.Period; +import org.junit.Assert; import org.junit.Test; import java.io.IOException; @@ -81,6 +89,12 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest { private static final boolean ROUND = true; + private static final ExprMacroTable MACRO_TABLE = new ExprMacroTable( + ImmutableList.of( + new HllPostAggExprMacros.HLLSketchEstimateExprMacro() + ) + ); + @Override public void gatherProperties(Properties properties) { @@ -121,6 +135,14 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest null, false, ROUND + ), + new HllSketchBuildAggregatorFactory( + "hllsketch_dim3", + "dim3", + null, + null, + false, + false ) ) .withRollup(false) @@ -129,6 +151,7 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest .rows(TestDataBuilder.ROWS1) .buildMMappedIndex(); + return new SpecificSegmentsQuerySegmentWalker(conglomerate).add( DataSegment.builder() .dataSource(CalciteTests.DATASOURCE1) @@ -888,4 +911,193 @@ public class HllSketchSqlAggregatorTest extends BaseCalciteQueryTest ImmutableList.of(new Object[]{"a", 0L, "0"}) ); } + + @Test + public void testHllEstimateAsVirtualColumn() + { + testQuery( + "SELECT" + + " HLL_SKETCH_ESTIMATE(hllsketch_dim1)" + + " FROM druid.foo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "hll_sketch_estimate(\"hllsketch_dim1\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{0.0D}, + new Object[]{1.0D}, + new Object[]{1.0D}, + new Object[]{1.0D}, + new Object[]{1.0D}, + new Object[]{1.0D} + ) + ); + } + + @Test + public void testHllEstimateAsVirtualColumnWithRound() + { + testQuery( + "SELECT" + + " HLL_SKETCH_ESTIMATE(hllsketch_dim3, FALSE), HLL_SKETCH_ESTIMATE(hllsketch_dim3, TRUE)" + + " FROM druid.foo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "hll_sketch_estimate(\"hllsketch_dim3\",0)", + ColumnType.DOUBLE, + MACRO_TABLE + ), new ExpressionVirtualColumn( + "v1", + "hll_sketch_estimate(\"hllsketch_dim3\",1)", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .columns("v0", "v1") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{2.000000004967054D, 2.0D}, + new Object[]{2.000000004967054D, 2.0D}, + new Object[]{1.0D, 1.0D}, + new Object[]{0.0D, 0.0D}, + new Object[]{0.0D, 0.0D}, + new Object[]{0.0D, 0.0D} + ) + ); + } + + @Test + public void testHllEstimateAsVirtualColumnOnNonHllCol() + { + try { + testQuery( + "SELECT" + + " HLL_SKETCH_ESTIMATE(dim2)" + + " FROM druid.foo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "hll_sketch_estimate(\"dim2\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of() + ); + } + catch (IllegalArgumentException e) { + Assert.assertTrue( + e.getMessage().contains("Input byte[] should at least have 2 bytes for base64 bytes") + ); + } + } + + @Test + public void testHllEstimateAsVirtualColumnWithGroupByOrderBy() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT" + + " HLL_SKETCH_ESTIMATE(hllsketch_dim1), count(*)" + + " FROM druid.foo" + + " GROUP BY 1" + + " ORDER BY 2 DESC", + ImmutableList.of( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "hll_sketch_estimate(\"hllsketch_dim1\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .setDimensions( + new DefaultDimensionSpec("v0", "d0", ColumnType.DOUBLE)) + .setAggregatorSpecs( + aggregators( + new CountAggregatorFactory("a0") + ) + ) + .setLimitSpec( + DefaultLimitSpec + .builder() + .orderBy( + new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ) + ) + .build() + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1.0D, 5L}, + new Object[]{0.0D, 1L} + ) + ); + } + + @Test + public void testHllEstimateAsVirtualColumnWithTopN() + { + testQuery( + "SELECT" + + " HLL_SKETCH_ESTIMATE(hllsketch_dim1), COUNT(*)" + + " FROM druid.foo" + + " GROUP BY 1 ORDER BY 2" + + " LIMIT 2", + ImmutableList.of( + new TopNQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.DOUBLE)) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "hll_sketch_estimate(\"hllsketch_dim1\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .metric(new InvertedTopNMetricSpec(new NumericTopNMetricSpec("a0"))) + .threshold(2) + .aggregators(new CountAggregatorFactory("a0")) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{0.0D, 1L}, + new Object[]{1.0D, 5L} + ) + ); + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java index f5a23178dd0..0887cc4e69a 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java @@ -27,6 +27,7 @@ import org.apache.druid.guice.DruidInjectorBuilder; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.PeriodGranularity; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.Druids; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryRunnerFactoryConglomerate; @@ -44,8 +45,13 @@ import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; +import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; +import org.apache.druid.query.topn.DimensionTopNMetricSpec; +import org.apache.druid.query.topn.TopNQueryBuilder; import org.apache.druid.segment.IndexBuilder; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.column.ColumnType; @@ -63,6 +69,7 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; import org.joda.time.DateTimeZone; import org.joda.time.Period; +import org.junit.Assert; import org.junit.Test; import java.io.IOException; @@ -74,6 +81,11 @@ import java.util.Properties; public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest { private static final String DATA_SOURCE = "foo"; + private static final ExprMacroTable MACRO_TABLE = new ExprMacroTable( + ImmutableList.of( + new ThetaPostAggMacros.ThetaSketchEstimateExprMacro() + ) + ); @Override public void gatherProperties(Properties properties) @@ -81,7 +93,10 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest super.gatherProperties(properties); // Use APPROX_COUNT_DISTINCT_DS_THETA as APPROX_COUNT_DISTINCT impl for these tests. - properties.put(SqlModule.PROPERTY_SQL_APPROX_COUNT_DISTINCT_CHOICE, ThetaSketchApproxCountDistinctSqlAggregator.NAME); + properties.put( + SqlModule.PROPERTY_SQL_APPROX_COUNT_DISTINCT_CHOICE, + ThetaSketchApproxCountDistinctSqlAggregator.NAME + ); } @Override @@ -1018,8 +1033,160 @@ public class ThetaSketchSqlAggregatorTest extends BaseCalciteQueryTest @Test public void testThetaSketchIntersectOnScalarExpression() { - assertQueryIsUnplannable("SELECT THETA_SKETCH_INTERSECT(NULL, NULL) FROM foo", + assertQueryIsUnplannable( + "SELECT THETA_SKETCH_INTERSECT(NULL, NULL) FROM foo", "Possible error: THETA_SKETCH_INTERSECT can only be used on aggregates. " + - "It cannot be used directly on a column or on a scalar expression."); + "It cannot be used directly on a column or on a scalar expression." + ); + } + + @Test + public void testThetaSketchEstimateAsVirtualColumn() + { + testQuery( + "SELECT" + + " THETA_SKETCH_ESTIMATE(thetasketch_dim1)" + + " FROM foo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "theta_sketch_estimate(\"thetasketch_dim1\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + NullHandling.replaceWithDefault() ? new Object[]{null} : new Object[]{0.0D}, + new Object[]{1.0D}, + new Object[]{1.0D}, + new Object[]{1.0D}, + new Object[]{1.0D}, + new Object[]{1.0D} + ) + ); + } + + @Test + public void testThetaEstimateAsVirtualColumnOnNonThetaCol() + { + try { + testQuery( + "SELECT" + + " THETA_SKETCH_ESTIMATE(dim2)" + + " FROM druid.foo", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "theta_sketch_estimate(\"dim2\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of() + ); + } + catch (IllegalArgumentException e) { + Assert.assertTrue( + e.getMessage().contains("requires a ThetaSketch as the argument") + ); + } + } + + @Test + public void testThetaEstimateAsVirtualColumnWithGroupByOrderBy() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT" + + " THETA_SKETCH_ESTIMATE(thetasketch_dim1), count(*)" + + " FROM druid.foo" + + " GROUP BY 1" + + " ORDER BY 2 DESC", + ImmutableList.of( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setVirtualColumns(new ExpressionVirtualColumn( + "v0", + "theta_sketch_estimate(\"thetasketch_dim1\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .setDimensions( + new DefaultDimensionSpec("v0", "d0", ColumnType.DOUBLE)) + .setAggregatorSpecs( + aggregators( + new CountAggregatorFactory("a0") + ) + ) + .setLimitSpec( + DefaultLimitSpec + .builder() + .orderBy( + new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ) + ) + .build() + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1.0D, 5L}, + new Object[]{0.0D, 1L} + ) + ); + } + + @Test + public void testThetaEstimateAsVirtualColumnWithTopN() + { + testQuery( + "SELECT" + + " THETA_SKETCH_ESTIMATE(thetasketch_dim1)" + + " FROM druid.foo" + + " GROUP BY 1 ORDER BY 1" + + " LIMIT 2", + ImmutableList.of( + new TopNQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.DOUBLE)) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "theta_sketch_estimate(\"thetasketch_dim1\")", + ColumnType.DOUBLE, + MACRO_TABLE + )) + .metric(new DimensionTopNMetricSpec(null, StringComparators.NUMERIC)) + .threshold(2) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{0.0D}, + new Object[]{1.0D} + ) + ); } }