Fix CAST being ignored when aggregating on strings after cast (#11083)

* Fix CAST being ignored when aggregating on strings after cast

* fix checkstyle and dependency

* unused import
This commit is contained in:
Jihoon Son 2021-04-12 22:21:24 -07:00 committed by GitHub
parent 0e0c1a1aaf
commit 25db8787b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 874 additions and 600 deletions

View File

@ -124,7 +124,7 @@ public class JoinAndLookupBenchmark
{
tmpDir = FileUtils.createTempDir();
ColumnConfig columnConfig = () -> columnCacheSizeBytes;
index = JoinTestHelper.createFactIndexBuilder(tmpDir, rows).buildMMappedIndex(columnConfig);
index = JoinTestHelper.createFactIndexBuilder(columnConfig, tmpDir, rows).buildMMappedIndex();
final String prefix = "c.";

View File

@ -133,6 +133,11 @@
<scope>provided</scope>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<scope>provided</scope>
</dependency>
<!-- Test Dependencies -->
<dependency>

View File

@ -39,6 +39,7 @@ import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -79,7 +80,7 @@ public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
inputOperand

View File

@ -42,6 +42,7 @@ import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -78,7 +79,7 @@ public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
)
{
// This is expected to be a tdigest sketch
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(

View File

@ -24,12 +24,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
@ -42,73 +41,40 @@ import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthTestUtils;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.SqlLifecycle;
import org.apache.druid.sql.SqlLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.QueryLogHook;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.http.SqlParameter;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
public class TDigestSketchSqlAggregatorTest extends BaseCalciteQueryTest
{
private static final String DATA_SOURCE = "foo";
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final Map<String, Object> QUERY_CONTEXT_DEFAULT = ImmutableMap.of(
PlannerContext.CTX_SQL_QUERY_ID, "dummy"
private static final AuthenticationResult AUTH_RESULT = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final DruidOperatorTable OPERATOR_TABLE = new DruidOperatorTable(
ImmutableSet.of(new TDigestSketchQuantileSqlAggregator(), new TDigestGenerateSketchSqlAggregator()),
ImmutableSet.of()
);
@BeforeClass
public static void setUpClass()
{
resourceCloser = Closer.create();
conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
}
@AfterClass
public static void tearDownClass() throws IOException
{
resourceCloser.close();
}
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
public QueryLogHook queryLogHook = QueryLogHook.create();
private SpecificSegmentsQuerySegmentWalker walker;
private SqlLifecycleFactory sqlLifecycleFactory;
@Before
public void setUp() throws Exception
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
TDigestSketchModule.registerSerde();
for (Module mod : new TDigestSketchModule().getJacksonModules()) {
@ -116,7 +82,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
}
final QueryableIndex index =
IndexBuilder.create()
IndexBuilder.create(CalciteTests.getJsonMapper())
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
@ -136,9 +102,9 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
.rows(CalciteTests.ROWS1)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
return new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.dataSource(CalciteTests.DATASOURCE1)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
@ -146,39 +112,45 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
.build(),
index
);
}
final PlannerConfig plannerConfig = new PlannerConfig();
final DruidOperatorTable operatorTable = new DruidOperatorTable(
ImmutableSet.of(new TDigestSketchQuantileSqlAggregator(), new TDigestGenerateSketchSqlAggregator()),
ImmutableSet.of()
);
SchemaPlus rootSchema =
CalciteTests.createMockRootSchema(conglomerate, walker, plannerConfig, AuthTestUtils.TEST_AUTHORIZER_MAPPER);
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
new PlannerFactory(
rootSchema,
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
operatorTable,
CalciteTests.createExprMacroTable(),
plannerConfig,
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper(),
CalciteTests.DRUID_SCHEMA_NAME
)
@Override
public List<Object[]> getResults(
final PlannerConfig plannerConfig,
final Map<String, Object> queryContext,
final List<SqlParameter> parameters,
final String sql,
final AuthenticationResult authenticationResult
) throws Exception
{
return getResults(
plannerConfig,
queryContext,
parameters,
sql,
authenticationResult,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
);
}
@After
public void tearDown() throws Exception
private SqlLifecycle getSqlLifecycle()
{
walker.close();
walker = null;
return getSqlLifecycleFactory(
BaseCalciteQueryTest.PLANNER_CONFIG_DEFAULT,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
).factorize();
}
@Test
public void testComputingSketchOnNumericValues() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "TDIGEST_GENERATE_SKETCH(m1, 200)"
+ "FROM foo";
@ -186,9 +158,9 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<String[]> expectedResults = ImmutableList.of(
new String[]{
@ -207,16 +179,56 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "m1", 200)
))
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testComputingSketchOnCastedString() throws Exception
{
cannotVectorize();
testQuery(
"SELECT\n"
+ "TDIGEST_GENERATE_SKETCH(CAST(dim1 AS DOUBLE), 200)"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.virtualColumns(
new ExpressionVirtualColumn(
"v0",
"CAST(\"dim1\", 'DOUBLE')",
ValueType.FLOAT,
ExprMacroTable.nil()
)
)
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "v0", 200)
))
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.build()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new String[]{
"\"AAAAAQAAAAAAAAAAQCQzMzMzMzNAaQAAAAAAAAAAAAY/8AAAAAAAAAAAAAAAAAAAP/AAAAAAAAAAAAAAAAAAAD/wAAAAAAAAAAAAAAAAAAA/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQCQzMzMzMzM=\""
}
: new String[]{
"\"AAAAAT/wAAAAAAAAQCQzMzMzMzNAaQAAAAAAAAAAAAM/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQCQzMzMzMzM=\""
}
)
);
}
@Test
public void testDefaultCompressionForTDigestGenerateSketchAgg() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "TDIGEST_GENERATE_SKETCH(m1)"
+ "FROM foo";
@ -224,9 +236,9 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
// Log query
sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
// Verify query
@ -238,7 +250,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "m1", TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION)
))
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -247,7 +259,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testComputingQuantileOnPreAggregatedSketch() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "TDIGEST_QUANTILE(qsketch_m1, 0.1),\n"
+ "TDIGEST_QUANTILE(qsketch_m1, 0.4),\n"
@ -258,9 +270,9 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<double[]> expectedResults = ImmutableList.of(
new double[]{
@ -296,7 +308,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a0:agg"), 0.8f),
new TDigestSketchToQuantilePostAggregator("a3", makeFieldAccessPostAgg("a0:agg"), 1.0f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -305,16 +317,16 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testGeneratingSketchAndComputingQuantileOnFly() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT TDIGEST_QUANTILE(x, 0.0), TDIGEST_QUANTILE(x, 0.5), TDIGEST_QUANTILE(x, 1.0)\n"
+ "FROM (SELECT dim1, TDIGEST_GENERATE_SKETCH(m1, 200) AS x FROM foo group by dim1)";
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<double[]> expectedResults = ImmutableList.of(
new double[]{
@ -348,7 +360,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
new TDigestSketchAggregatorFactory("a0:agg", "m1", 200)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(TIMESERIES_CONTEXT_DEFAULT)
.build()
)
)
@ -366,7 +378,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
new TDigestSketchToQuantilePostAggregator("_a2", makeFieldAccessPostAgg("_a0:agg"), 1.0f)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -375,7 +387,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testQuantileOnNumericValues() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "TDIGEST_QUANTILE(m1, 0.0), TDIGEST_QUANTILE(m1, 0.5), TDIGEST_QUANTILE(m1, 1.0)\n"
+ "FROM foo";
@ -383,9 +395,9 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<double[]> expectedResults = ImmutableList.of(
new double[]{
@ -418,7 +430,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a0:agg"), 0.5f),
new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a0:agg"), 1.0f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -427,7 +439,7 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testCompressionParamForTDigestQuantileAgg() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "TDIGEST_QUANTILE(m1, 0.0), TDIGEST_QUANTILE(m1, 0.5, 200), TDIGEST_QUANTILE(m1, 1.0, 300)\n"
+ "FROM foo";
@ -435,9 +447,9 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
// Log query
sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
// Verify query
@ -462,12 +474,64 @@ public class TDigestSketchSqlAggregatorTest extends CalciteTestBase
new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.5f),
new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 1.0f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testQuantileOnCastedString() throws Exception
{
cannotVectorize();
testQuery(
"SELECT\n"
+ " TDIGEST_QUANTILE(CAST(dim1 AS DOUBLE), 0.0),\n"
+ " TDIGEST_QUANTILE(CAST(dim1 AS DOUBLE), 0.5),\n"
+ " TDIGEST_QUANTILE(CAST(dim1 AS DOUBLE), 1.0)\n"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
new ExpressionVirtualColumn(
"v0",
"CAST(\"dim1\", 'DOUBLE')",
ValueType.FLOAT,
ExprMacroTable.nil()
)
)
.aggregators(new TDigestSketchAggregatorFactory("a0:agg", "v0", 100))
.postAggregators(
new TDigestSketchToQuantilePostAggregator(
"a0",
new FieldAccessPostAggregator("a0:agg", "a0:agg"),
0.0
),
new TDigestSketchToQuantilePostAggregator(
"a1",
new FieldAccessPostAggregator("a0:agg", "a0:agg"),
0.5
),
new TDigestSketchToQuantilePostAggregator(
"a2",
new FieldAccessPostAggregator("a0:agg", "a0:agg"),
1.0
)
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{0.0, 0.5, 10.1}
: new Object[]{1.0, 2.0, 10.1}
)
);
}
private static PostAggregator makeFieldAccessPostAgg(String name)
{

View File

@ -42,6 +42,7 @@ import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -77,7 +78,7 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
final boolean finalizeAggregations
)
{
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(

View File

@ -39,6 +39,7 @@ import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -74,7 +75,7 @@ public class DoublesSketchObjectSqlAggregator implements SqlAggregator
final boolean finalizeAggregations
)
{
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(

View File

@ -24,15 +24,12 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.Druids;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
@ -57,84 +54,57 @@ import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthTestUtils;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.SqlLifecycle;
import org.apache.druid.sql.SqlLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.QueryLogHook;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.http.SqlParameter;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.util.List;
import java.util.Map;
public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
public class DoublesSketchSqlAggregatorTest extends BaseCalciteQueryTest
{
private static final String DATA_SOURCE = "foo";
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final Map<String, Object> QUERY_CONTEXT_DEFAULT = ImmutableMap.of(
PlannerContext.CTX_SQL_QUERY_ID, "dummy"
private static final AuthenticationResult AUTH_RESULT = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final DruidOperatorTable OPERATOR_TABLE = new DruidOperatorTable(
ImmutableSet.of(
new DoublesSketchApproxQuantileSqlAggregator(),
new DoublesSketchObjectSqlAggregator()
),
ImmutableSet.of(
new DoublesSketchQuantileOperatorConversion(),
new DoublesSketchQuantilesOperatorConversion(),
new DoublesSketchToHistogramOperatorConversion(),
new DoublesSketchRankOperatorConversion(),
new DoublesSketchCDFOperatorConversion(),
new DoublesSketchSummaryOperatorConversion()
)
);
@BeforeClass
public static void setUpClass()
{
resourceCloser = Closer.create();
conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
}
@AfterClass
public static void tearDownClass() throws IOException
{
resourceCloser.close();
}
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
public QueryLogHook queryLogHook = QueryLogHook.create();
private SpecificSegmentsQuerySegmentWalker walker;
private SqlLifecycleFactory sqlLifecycleFactory;
@Before
public void setUp() throws Exception
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
DoublesSketchModule.registerSerde();
for (Module mod : new DoublesSketchModule().getJacksonModules()) {
CalciteTests.getJsonMapper().registerModule(mod);
TestHelper.JSON_MAPPER.registerModule(mod);
}
final QueryableIndex index =
IndexBuilder.create()
IndexBuilder.create(CalciteTests.getJsonMapper())
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
@ -154,9 +124,9 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
.rows(CalciteTests.ROWS1)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
return new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.dataSource(CalciteTests.DATASOURCE1)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
@ -164,50 +134,45 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
.build(),
index
);
}
final PlannerConfig plannerConfig = new PlannerConfig();
final DruidOperatorTable operatorTable = new DruidOperatorTable(
ImmutableSet.of(
new DoublesSketchApproxQuantileSqlAggregator(),
new DoublesSketchObjectSqlAggregator()
),
ImmutableSet.of(
new DoublesSketchQuantileOperatorConversion(),
new DoublesSketchQuantilesOperatorConversion(),
new DoublesSketchToHistogramOperatorConversion(),
new DoublesSketchRankOperatorConversion(),
new DoublesSketchCDFOperatorConversion(),
new DoublesSketchSummaryOperatorConversion()
)
);
SchemaPlus rootSchema =
CalciteTests.createMockRootSchema(conglomerate, walker, plannerConfig, AuthTestUtils.TEST_AUTHORIZER_MAPPER);
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
new PlannerFactory(
rootSchema,
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
operatorTable,
CalciteTests.createExprMacroTable(),
plannerConfig,
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper(),
CalciteTests.DRUID_SCHEMA_NAME
)
@Override
public List<Object[]> getResults(
final PlannerConfig plannerConfig,
final Map<String, Object> queryContext,
final List<SqlParameter> parameters,
final String sql,
final AuthenticationResult authenticationResult
) throws Exception
{
return getResults(
plannerConfig,
queryContext,
parameters,
sql,
authenticationResult,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
);
}
@After
public void tearDown() throws Exception
private SqlLifecycle getSqlLifecycle()
{
walker.close();
walker = null;
return getSqlLifecycleFactory(
BaseCalciteQueryTest.PLANNER_CONFIG_DEFAULT,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
).factorize();
}
@Test
public void testQuantileOnFloatAndLongs() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "APPROX_QUANTILE_DS(m1, 0.01),\n"
+ "APPROX_QUANTILE_DS(m1, 0.5, 64),\n"
@ -223,9 +188,9 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -285,7 +250,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
new DoublesSketchToQuantilePostAggregator("a7", makeFieldAccessPostAgg("a5:agg"), 0.999f),
new DoublesSketchToQuantilePostAggregator("a8", makeFieldAccessPostAgg("a8:agg"), 0.50f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -294,7 +259,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testQuantileOnComplexColumn() throws Exception
{
SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle lifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.01),\n"
+ "APPROX_QUANTILE_DS(qsketch_m1, 0.5, 64),\n"
@ -308,9 +273,9 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = lifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -356,16 +321,111 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
new DoublesSketchToQuantilePostAggregator("a5", makeFieldAccessPostAgg("a5:agg"), 0.999f),
new DoublesSketchToQuantilePostAggregator("a6", makeFieldAccessPostAgg("a4:agg"), 0.999f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testQuantileOnCastedString() throws Exception
{
cannotVectorize();
final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) {
expectedResults = ImmutableList.of(
new Object[]{
0.0,
1.0,
10.1,
10.1,
20.2,
0.0,
10.1,
0.0
}
);
} else {
expectedResults = ImmutableList.of(
new Object[]{
1.0,
2.0,
10.1,
10.1,
20.2,
Double.NaN,
10.1,
Double.NaN
}
);
}
testQuery(
"SELECT\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.01),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.5, 64),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.98, 256),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.99),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE) * 2, 0.97),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.99) FILTER(WHERE dim2 = 'abc'),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.999) FILTER(WHERE dim2 <> 'abc'),\n"
+ "APPROX_QUANTILE_DS(CAST(dim1 as DOUBLE), 0.999) FILTER(WHERE dim2 = 'abc')\n"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.virtualColumns(
new ExpressionVirtualColumn(
"v0",
"CAST(\"dim1\", 'DOUBLE')",
ValueType.FLOAT,
TestExprMacroTable.INSTANCE
),
new ExpressionVirtualColumn(
"v1",
"(CAST(\"dim1\", 'DOUBLE') * 2)",
ValueType.FLOAT,
TestExprMacroTable.INSTANCE
)
)
.aggregators(ImmutableList.of(
new DoublesSketchAggregatorFactory("a0:agg", "v0", 128),
new DoublesSketchAggregatorFactory("a1:agg", "v0", 64),
new DoublesSketchAggregatorFactory("a2:agg", "v0", 256),
new DoublesSketchAggregatorFactory("a4:agg", "v1", 128),
new FilteredAggregatorFactory(
new DoublesSketchAggregatorFactory("a5:agg", "v0", 128),
new SelectorDimFilter("dim2", "abc", null)
),
new FilteredAggregatorFactory(
new DoublesSketchAggregatorFactory("a6:agg", "v0", 128),
new NotDimFilter(new SelectorDimFilter("dim2", "abc", null))
)
))
.postAggregators(
new DoublesSketchToQuantilePostAggregator("a0", makeFieldAccessPostAgg("a0:agg"), 0.01f),
new DoublesSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.50f),
new DoublesSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 0.98f),
new DoublesSketchToQuantilePostAggregator("a3", makeFieldAccessPostAgg("a0:agg"), 0.99f),
new DoublesSketchToQuantilePostAggregator("a4", makeFieldAccessPostAgg("a4:agg"), 0.97f),
new DoublesSketchToQuantilePostAggregator("a5", makeFieldAccessPostAgg("a5:agg"), 0.99f),
new DoublesSketchToQuantilePostAggregator("a6", makeFieldAccessPostAgg("a6:agg"), 0.999f),
new DoublesSketchToQuantilePostAggregator("a7", makeFieldAccessPostAgg("a5:agg"), 0.999f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.build()
),
expectedResults
);
}
@Test
public void testQuantileOnInnerQuery() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT AVG(x), APPROX_QUANTILE_DS(x, 0.98)\n"
+ "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)";
@ -374,7 +434,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
sql,
QUERY_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) {
@ -402,7 +462,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
new DoubleSumAggregatorFactory("a0", "m1")
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
@ -430,7 +490,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
new DoublesSketchToQuantilePostAggregator("_a1", makeFieldAccessPostAgg("_a1:agg"), 0.98f)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -439,7 +499,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testQuantileOnInnerQuantileQuery() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT dim1, APPROX_QUANTILE_DS(x, 0.5)\n"
+ "FROM (SELECT dim1, dim2, APPROX_QUANTILE_DS(m1, 0.5) AS x FROM foo GROUP BY dim1, dim2) GROUP BY dim1";
@ -448,7 +508,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
sql,
QUERY_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
ImmutableList.Builder<Object[]> builder = ImmutableList.builder();
@ -491,7 +551,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
@ -506,7 +566,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
new DoublesSketchToQuantilePostAggregator("_a0", makeFieldAccessPostAgg("_a0:agg"), 0.5f)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -515,7 +575,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testDoublesSketchPostAggs() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ " SUM(cnt),\n"
+ " APPROX_QUANTILE_DS(cnt, 0.5) + 1,\n"
@ -532,9 +592,9 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -682,12 +742,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
)
)
)
.context(ImmutableMap.of(
"skipEmptyBuckets",
true,
PlannerContext.CTX_SQL_QUERY_ID,
"dummy"
))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build();
// Verify query
@ -697,7 +752,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
@Test
public void testDoublesSketchPostAggsPostSort() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT DS_QUANTILES_SKETCH(m1) as y FROM druid.foo ORDER BY DS_GET_QUANTILE(DS_QUANTILES_SKETCH(m1), 0.5) DESC LIMIT 10";
final String sql2 = StringUtils.format("SELECT DS_GET_QUANTILE(y, 0.5), DS_GET_QUANTILE(y, 0.98) from (%s)", sql);
@ -705,9 +760,9 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql2,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -749,10 +804,7 @@ public class DoublesSketchSqlAggregatorTest extends CalciteTestBase
)
)
)
.context(ImmutableMap.of(
"skipEmptyBuckets", true,
PlannerContext.CTX_SQL_QUERY_ID, "dummy"
))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build();
// Verify query

View File

@ -45,6 +45,9 @@ public class ApproximateHistogramBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
if (selector.isNull()) {
return;
}
innerAggregator.aggregate(buf, position, selector.getFloat());
}

View File

@ -42,6 +42,7 @@ import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -77,7 +78,7 @@ public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
boolean finalizeAggregations
)
{
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(

View File

@ -43,6 +43,7 @@ import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -78,7 +79,7 @@ public class QuantileSqlAggregator implements SqlAggregator
final boolean finalizeAggregations
)
{
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(

View File

@ -24,14 +24,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.Druids;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
@ -53,77 +50,45 @@ import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthTestUtils;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.SqlLifecycle;
import org.apache.druid.sql.SqlLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.QueryLogHook;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.http.SqlParameter;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.util.List;
import java.util.Map;
public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestBase
public class FixedBucketsHistogramQuantileSqlAggregatorTest extends BaseCalciteQueryTest
{
private static final String DATA_SOURCE = "foo";
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final AuthenticationResult AUTH_RESULT = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final Map<String, Object> QUERY_CONTEXT_DEFAULT = ImmutableMap.of(
PlannerContext.CTX_SQL_QUERY_ID, "dummy"
);
private static final DruidOperatorTable OPERATOR_TABLE = new DruidOperatorTable(
ImmutableSet.of(new QuantileSqlAggregator(), new FixedBucketsHistogramQuantileSqlAggregator()),
ImmutableSet.of()
);
@BeforeClass
public static void setUpClass()
{
resourceCloser = Closer.create();
conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
}
@AfterClass
public static void tearDownClass() throws IOException
{
resourceCloser.close();
}
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
public QueryLogHook queryLogHook = QueryLogHook.create();
private SpecificSegmentsQuerySegmentWalker walker;
private SqlLifecycleFactory sqlLifecycleFactory;
@Before
public void setUp() throws Exception
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
ApproximateHistogramDruidModule.registerSerde();
for (Module mod : new ApproximateHistogramDruidModule().getJacksonModules()) {
CalciteTests.getJsonMapper().registerModule(mod);
}
final QueryableIndex index = IndexBuilder.create()
final QueryableIndex index = IndexBuilder.create(CalciteTests.getJsonMapper())
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
@ -147,9 +112,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
.rows(CalciteTests.ROWS1)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
return new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.dataSource(CalciteTests.DATASOURCE1)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
@ -157,40 +122,45 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
.build(),
index
);
}
final PlannerConfig plannerConfig = new PlannerConfig();
final DruidOperatorTable operatorTable = new DruidOperatorTable(
ImmutableSet.of(new QuantileSqlAggregator(), new FixedBucketsHistogramQuantileSqlAggregator()),
ImmutableSet.of()
);
SchemaPlus rootSchema =
CalciteTests.createMockRootSchema(conglomerate, walker, plannerConfig, AuthTestUtils.TEST_AUTHORIZER_MAPPER);
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
new PlannerFactory(
rootSchema,
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
operatorTable,
CalciteTests.createExprMacroTable(),
plannerConfig,
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper(),
CalciteTests.DRUID_SCHEMA_NAME
)
@Override
public List<Object[]> getResults(
final PlannerConfig plannerConfig,
final Map<String, Object> queryContext,
final List<SqlParameter> parameters,
final String sql,
final AuthenticationResult authenticationResult
) throws Exception
{
return getResults(
plannerConfig,
queryContext,
parameters,
sql,
authenticationResult,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
);
}
@After
public void tearDown() throws Exception
private SqlLifecycle getSqlLifecycle()
{
walker.close();
walker = null;
return getSqlLifecycleFactory(
BaseCalciteQueryTest.PLANNER_CONFIG_DEFAULT,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
).factorize();
}
@Test
public void testQuantileOnFloatAndLongs() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.01, 20, 0.0, 10.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(m1, 0.5, 20, 0.0, 10.0),\n"
@ -206,9 +176,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -305,7 +275,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
new QuantilePostAggregator("a7", "a5:agg", 0.999f),
new QuantilePostAggregator("a8", "a8:agg", 0.50f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build();
// Verify query
@ -315,10 +285,128 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
);
}
@Test
public void testQuentileOnCastedString() throws Exception
{
cannotVectorize();
testQuery(
"SELECT\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.01, 20, 0.0, 10.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.5, 20, 0.0, 10.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.98, 20, 0.0, 10.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.99, 20, 0.0, 10.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE) * 2, 0.97, 40, 0.0, 20.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.99, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 <> 'abc'),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(CAST(dim1 AS DOUBLE), 0.999, 20, 0.0, 10.0) FILTER(WHERE dim1 = 'abc')\n"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.virtualColumns(
new ExpressionVirtualColumn(
"v0",
"CAST(\"dim1\", 'DOUBLE')",
ValueType.FLOAT,
TestExprMacroTable.INSTANCE
),
new ExpressionVirtualColumn(
"v1",
"(CAST(\"dim1\", 'DOUBLE') * 2)",
ValueType.FLOAT,
TestExprMacroTable.INSTANCE
)
)
.aggregators(ImmutableList.of(
new FixedBucketsHistogramAggregatorFactory(
"a0:agg",
"v0",
20,
0.0d,
10.0d,
FixedBucketsHistogram.OutlierHandlingMode.IGNORE,
false
),
new FixedBucketsHistogramAggregatorFactory(
"a4:agg",
"v1",
40,
0.0d,
20.0d,
FixedBucketsHistogram.OutlierHandlingMode.IGNORE,
false
),
new FilteredAggregatorFactory(
new FixedBucketsHistogramAggregatorFactory(
"a5:agg",
"v0",
20,
0.0d,
10.0d,
FixedBucketsHistogram.OutlierHandlingMode.IGNORE,
false
),
new SelectorDimFilter("dim1", "abc", null)
),
new FilteredAggregatorFactory(
new FixedBucketsHistogramAggregatorFactory(
"a6:agg",
"v0",
20,
0.0d,
10.0d,
FixedBucketsHistogram.OutlierHandlingMode.IGNORE,
false
),
new NotDimFilter(new SelectorDimFilter("dim1", "abc", null))
)
))
.postAggregators(
new QuantilePostAggregator("a0", "a0:agg", 0.01f),
new QuantilePostAggregator("a1", "a0:agg", 0.50f),
new QuantilePostAggregator("a2", "a0:agg", 0.98f),
new QuantilePostAggregator("a3", "a0:agg", 0.99f),
new QuantilePostAggregator("a4", "a4:agg", 0.97f),
new QuantilePostAggregator("a5", "a5:agg", 0.99f),
new QuantilePostAggregator("a6", "a6:agg", 0.999f),
new QuantilePostAggregator("a7", "a5:agg", 0.999f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.build()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{
0.00833333283662796,
0.4166666567325592,
2.450000047683716,
2.4749999046325684,
4.425000190734863,
0.4950000047683716,
2.498000144958496,
0.49950000643730164
}
: new Object[]{
1.0099999904632568,
1.5,
2.4800000190734863,
2.490000009536743,
4.470000267028809,
0.0,
2.499000072479248,
0.0
}
)
);
}
@Test
public void testQuantileOnComplexColumn() throws Exception
{
SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle lifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(fbhist_m1, 0.01, 20, 0.0, 10.0),\n"
+ "APPROX_QUANTILE_FIXED_BUCKETS(fbhist_m1, 0.5, 20, 0.0, 10.0),\n"
@ -332,9 +420,9 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
// Verify results
final List<Object[]> results = lifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -410,7 +498,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
new QuantilePostAggregator("a5", "a5:agg", 0.999f),
new QuantilePostAggregator("a6", "a4:agg", 0.999f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build();
// Verify query
@ -420,7 +508,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
@Test
public void testQuantileOnInnerQuery() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT AVG(x), APPROX_QUANTILE_FIXED_BUCKETS(x, 0.98, 100, 0.0, 100.0)\n"
+ "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)";
@ -429,7 +517,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
sql,
QUERY_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) {
@ -457,10 +545,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
new DoubleSumAggregatorFactory("a0", "m1")
)
)
.setContext(ImmutableMap.of(
PlannerContext.CTX_SQL_QUERY_ID,
"dummy"
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
@ -492,7 +577,7 @@ public class FixedBucketsHistogramQuantileSqlAggregatorTest extends CalciteTestB
new QuantilePostAggregator("_a1", "_a1:agg", 0.98f)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build();
// Verify query

View File

@ -24,13 +24,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
@ -48,83 +46,51 @@ import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthTestUtils;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.SqlLifecycle;
import org.apache.druid.sql.SqlLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.QueryLogHook;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.http.SqlParameter;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.IOException;
import java.util.List;
import java.util.Map;
public class QuantileSqlAggregatorTest extends CalciteTestBase
public class QuantileSqlAggregatorTest extends BaseCalciteQueryTest
{
private static final String DATA_SOURCE = "foo";
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final AuthenticationResult AUTH_RESULT = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final Map<String, Object> QUERY_CONTEXT_DEFAULT = ImmutableMap.of(
PlannerContext.CTX_SQL_QUERY_ID, "dummy"
PlannerContext.CTX_SQL_QUERY_ID,
"dummy"
);
@BeforeClass
public static void setUpClass()
{
resourceCloser = Closer.create();
conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
}
private static final DruidOperatorTable OPERATOR_TABLE = new DruidOperatorTable(
ImmutableSet.of(new QuantileSqlAggregator()),
ImmutableSet.of()
);
@AfterClass
public static void tearDownClass() throws IOException
{
resourceCloser.close();
}
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
public QueryLogHook queryLogHook = QueryLogHook.create();
private SpecificSegmentsQuerySegmentWalker walker;
private SqlLifecycleFactory sqlLifecycleFactory;
@Before
public void setUp() throws Exception
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
ApproximateHistogramDruidModule.registerSerde();
for (Module mod : new ApproximateHistogramDruidModule().getJacksonModules()) {
CalciteTests.getJsonMapper().registerModule(mod);
TestHelper.JSON_MAPPER.registerModule(mod);
}
final QueryableIndex index = IndexBuilder.create()
final QueryableIndex index = IndexBuilder.create(CalciteTests.getJsonMapper())
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
@ -148,9 +114,9 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
.rows(CalciteTests.ROWS1)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
return new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.dataSource(CalciteTests.DATASOURCE1)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
@ -158,40 +124,45 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
.build(),
index
);
}
final PlannerConfig plannerConfig = new PlannerConfig();
final DruidOperatorTable operatorTable = new DruidOperatorTable(
ImmutableSet.of(new QuantileSqlAggregator()),
ImmutableSet.of()
);
SchemaPlus rootSchema =
CalciteTests.createMockRootSchema(conglomerate, walker, plannerConfig, AuthTestUtils.TEST_AUTHORIZER_MAPPER);
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
new PlannerFactory(
rootSchema,
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
operatorTable,
CalciteTests.createExprMacroTable(),
plannerConfig,
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper(),
CalciteTests.DRUID_SCHEMA_NAME
)
@Override
public List<Object[]> getResults(
final PlannerConfig plannerConfig,
final Map<String, Object> queryContext,
final List<SqlParameter> parameters,
final String sql,
final AuthenticationResult authenticationResult
) throws Exception
{
return getResults(
plannerConfig,
queryContext,
parameters,
sql,
authenticationResult,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
);
}
@After
public void tearDown() throws Exception
private SqlLifecycle getSqlLifecycle()
{
walker.close();
walker = null;
return getSqlLifecycleFactory(
BaseCalciteQueryTest.PLANNER_CONFIG_DEFAULT,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
).factorize();
}
@Test
public void testQuantileOnFloatAndLongs() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "APPROX_QUANTILE(m1, 0.01),\n"
@ -208,9 +179,9 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = sqlLifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
@ -269,7 +240,7 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
new QuantilePostAggregator("a7", "a5:agg", 0.999f),
new QuantilePostAggregator("a8", "a8:agg", 0.50f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -278,7 +249,7 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
@Test
public void testQuantileOnComplexColumn() throws Exception
{
SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle lifecycle = getSqlLifecycle();
final String sql = "SELECT\n"
+ "APPROX_QUANTILE(hist_m1, 0.01),\n"
+ "APPROX_QUANTILE(hist_m1, 0.5, 50),\n"
@ -292,9 +263,9 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
// Verify results
final List<Object[]> results = lifecycle.runSimple(
sql,
QUERY_CONTEXT_DEFAULT,
TIMESERIES_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0}
@ -331,7 +302,7 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
new QuantilePostAggregator("a5", "a5:agg", 0.999f),
new QuantilePostAggregator("a6", "a4:agg", 0.999f)
)
.context(ImmutableMap.of("skipEmptyBuckets", true, PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
@ -340,7 +311,7 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
@Test
public void testQuantileOnInnerQuery() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
SqlLifecycle sqlLifecycle = getSqlLifecycle();
final String sql = "SELECT AVG(x), APPROX_QUANTILE(x, 0.98)\n"
+ "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)";
@ -349,7 +320,7 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
sql,
QUERY_CONTEXT_DEFAULT,
DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) {
@ -377,7 +348,7 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
new DoubleSumAggregatorFactory("a0", "m1")
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
)
)
@ -409,9 +380,72 @@ public class QuantileSqlAggregatorTest extends CalciteTestBase
new QuantilePostAggregator("_a1", "_a1:agg", 0.98f)
)
)
.setContext(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy"))
.setContext(QUERY_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testQuantileOnCastedString() throws Exception
{
cannotVectorize();
final List<Object[]> expectedResults;
if (NullHandling.replaceWithDefault()) {
expectedResults = ImmutableList.of(
new Object[]{"", 0.0d},
new Object[]{"a", 0.0d},
new Object[]{"b", 0.0d},
new Object[]{"c", 10.100000381469727d},
new Object[]{"d", 2.0d}
);
} else {
expectedResults = ImmutableList.of(
new Object[]{null, Double.NaN},
new Object[]{"", 1.0d},
new Object[]{"a", Double.NaN},
new Object[]{"b", 10.100000381469727d},
new Object[]{"c", 10.100000381469727d},
new Object[]{"d", 2.0d}
);
}
testQuery(
"SELECT dim3, APPROX_QUANTILE(CAST(dim1 as DOUBLE), 0.5) from foo group by dim3",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setVirtualColumns(
new ExpressionVirtualColumn(
"v0",
"CAST(\"dim1\", 'DOUBLE')",
ValueType.FLOAT,
ExprMacroTable.nil()
)
)
.setDimensions(new DefaultDimensionSpec("dim3", "d0"))
.setAggregatorSpecs(
new ApproximateHistogramAggregatorFactory(
"a0:agg",
"v0",
50,
7,
Float.NEGATIVE_INFINITY,
Float.POSITIVE_INFINITY,
false
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
new QuantilePostAggregator("a0", "a0:agg", 0.5f)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
expectedResults
);
}
}

View File

@ -39,6 +39,7 @@ import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
@ -71,7 +72,7 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
project,
aggregateCall.getArgList().get(0)
);
final DruidExpression input = Expressions.toDruidExpression(
final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(
plannerContext,
rowSignature,
inputOperand

View File

@ -19,7 +19,9 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryPlus;
@ -31,13 +33,16 @@ import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerTest;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@ -118,6 +123,39 @@ public class VarianceTimeseriesQueryTest extends InitializedNullHandlingTest
assertExpectedResults(expectedResults, results);
}
@Test
public void testEmptyTimeseries()
{
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.ALL_GRAN)
.intervals(QueryRunnerTestHelper.EMPTY_INTERVAL)
.aggregators(
Arrays.asList(
QueryRunnerTestHelper.ROWS_COUNT,
QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
new VarianceAggregatorFactory("variance", "index", null, null)
)
)
.descending(true)
.context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build();
Map<String, Object> resultMap = new HashMap<>();
resultMap.put("rows", 0L);
resultMap.put("index", NullHandling.defaultDoubleValue());
resultMap.put("variance", NullHandling.defaultDoubleValue());
List<Result<TimeseriesResultValue>> expectedResults = ImmutableList.of(
new Result<>(
DateTimes.of("2020-04-02"),
new TimeseriesResultValue(
resultMap
)
)
);
Iterable<Result<TimeseriesResultValue>> actualResults = runner.run(QueryPlus.wrap(query)).toList();
TestHelper.assertExpectedResults(expectedResults, actualResults);
}
private <T> void assertExpectedResults(Iterable<Result<T>> expectedResults, Iterable<Result<T>> results)
{
if (descending) {

View File

@ -22,29 +22,16 @@ package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.impl.DimensionSchema;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.DoubleDimensionSchema;
import org.apache.druid.data.input.impl.FloatDimensionSchema;
import org.apache.druid.data.input.impl.InputRowParser;
import org.apache.druid.data.input.impl.LongDimensionSchema;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.QueryRunnerTestHelper;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
@ -56,119 +43,79 @@ import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthTestUtils;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.sql.SqlLifecycle;
import org.apache.druid.sql.SqlLifecycleFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.util.CalciteTestBase;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.QueryLogHook;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.apache.druid.sql.http.SqlParameter;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RunWith(JUnitParamsRunner.class)
public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
public class VarianceSqlAggregatorTest extends BaseCalciteQueryTest
{
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final AuthenticationResult AUTH_RESULT = CalciteTests.REGULAR_USER_AUTH_RESULT;
private static final String DATA_SOURCE = "numfoo";
private static final DruidOperatorTable OPERATOR_TABLE = new DruidOperatorTable(
ImmutableSet.of(
new BaseVarianceSqlAggregator.VarPopSqlAggregator(),
new BaseVarianceSqlAggregator.VarSampSqlAggregator(),
new BaseVarianceSqlAggregator.VarianceSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevPopSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevSampSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevSqlAggregator()
),
ImmutableSet.of()
);
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
private SqlLifecycle sqlLifecycle;
@BeforeClass
public static void setUpClass()
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
resourceCloser = Closer.create();
conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
}
@AfterClass
public static void tearDownClass() throws IOException
{
resourceCloser.close();
}
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
public QueryLogHook queryLogHook = QueryLogHook.create();
private SpecificSegmentsQuerySegmentWalker walker;
private SqlLifecycleFactory sqlLifecycleFactory;
@Before
public void setUp() throws Exception
{
InputRowParser parser = new MapInputRowParser(
new TimeAndDimsParseSpec(
new TimestampSpec("t", "iso", null),
new DimensionsSpec(
ImmutableList.<DimensionSchema>builder()
.addAll(DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3")))
.add(new DoubleDimensionSchema("d1"))
.add(new FloatDimensionSchema("f1"))
.add(new LongDimensionSchema("l1"))
.build(),
null,
null
)
));
final QueryableIndex index =
IndexBuilder.create()
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
new IncrementalIndexSchema.Builder()
.withDimensionsSpec(
new DimensionsSpec(
ImmutableList.<DimensionSchema>builder()
.addAll(DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3")))
.add(new DoubleDimensionSchema("d1"))
.add(new FloatDimensionSchema("f1"))
.add(new LongDimensionSchema("l1"))
.build(),
null,
null
)
)
.withMetrics(
new CountAggregatorFactory("cnt"),
new DoubleSumAggregatorFactory("m1", "m1")
)
.withDimensionsSpec(parser)
.withRollup(false)
.build()
)
.rows(CalciteTests.ROWS1_WITH_NUMERIC_DIMS)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
return new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(DATA_SOURCE)
.interval(index.getDataInterval())
@ -178,43 +125,39 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
.build(),
index
);
final PlannerConfig plannerConfig = new PlannerConfig();
final DruidOperatorTable operatorTable = new DruidOperatorTable(
ImmutableSet.of(
new BaseVarianceSqlAggregator.VarPopSqlAggregator(),
new BaseVarianceSqlAggregator.VarSampSqlAggregator(),
new BaseVarianceSqlAggregator.VarianceSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevPopSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevSampSqlAggregator(),
new BaseVarianceSqlAggregator.StdDevSqlAggregator()
),
ImmutableSet.of()
);
SchemaPlus rootSchema =
CalciteTests.createMockRootSchema(conglomerate, walker, plannerConfig, AuthTestUtils.TEST_AUTHORIZER_MAPPER);
sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
new PlannerFactory(
rootSchema,
CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
operatorTable,
CalciteTests.createExprMacroTable(),
plannerConfig,
AuthTestUtils.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper(),
CalciteTests.DRUID_SCHEMA_NAME
)
);
queryLogHook.clearRecordedQueries();
sqlLifecycle = sqlLifecycleFactory.factorize();
}
@After
public void tearDown() throws Exception
@Override
public List<Object[]> getResults(
final PlannerConfig plannerConfig,
final Map<String, Object> queryContext,
final List<SqlParameter> parameters,
final String sql,
final AuthenticationResult authenticationResult
) throws Exception
{
walker.close();
walker = null;
return getResults(
plannerConfig,
queryContext,
parameters,
sql,
authenticationResult,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
);
}
private SqlLifecycle getSqlLifecycle()
{
return getSqlLifecycleFactory(
BaseCalciteQueryTest.PLANNER_CONFIG_DEFAULT,
OPERATOR_TABLE,
CalciteTests.createExprMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
CalciteTests.getJsonMapper()
).factorize();
}
public void addToHolder(VarianceAggregatorCollector holder, Object raw)
@ -255,11 +198,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(
getSqlLifecycle().runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
@ -281,7 +224,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
holder3.getVariance(true).longValue()
}
);
assertResultsEquals(expectedResults, results);
assertResultsEquals(sql, expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@ -311,11 +254,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(
getSqlLifecycle().runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
@ -337,7 +280,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
holder3.getVariance(false).longValue(),
}
);
assertResultsEquals(expectedResults, results);
assertResultsEquals(sql, expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@ -367,11 +310,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(
getSqlLifecycle().runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
@ -393,7 +336,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
(long) Math.sqrt(holder3.getVariance(true)),
}
);
assertResultsEquals(expectedResults, results);
assertResultsEquals(sql, expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@ -430,11 +373,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(
getSqlLifecycle().runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
@ -456,7 +399,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
(long) Math.sqrt(holder3.getVariance(false)),
}
);
assertResultsEquals(expectedResults, results);
assertResultsEquals(sql, expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@ -491,11 +434,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
+ "FROM numfoo";
final List<Object[]> results =
sqlLifecycle.runSimple(
getSqlLifecycle().runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
@ -517,7 +460,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
(long) Math.sqrt(holder3.getVariance(false)),
}
);
assertResultsEquals(expectedResults, results);
assertResultsEquals(sql, expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@ -554,11 +497,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
queryLogHook.clearRecordedQueries();
final String sql = "select dim2, VARIANCE(f1) from druid.numfoo group by 1 order by 2 desc";
final List<Object[]> results =
sqlLifecycle.runSimple(
getSqlLifecycle().runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
AUTH_RESULT
).toList();
List<Object[]> expectedResults = NullHandling.sqlCompatible()
? ImmutableList.of(
@ -571,7 +514,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
new Object[]{"", 0.0033333334f},
new Object[]{"abc", 0f}
);
assertResultsEquals(expectedResults, results);
assertResultsEquals(sql, expectedResults, results);
Assert.assertEquals(
GroupByQuery.builder()
@ -600,52 +543,48 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
);
}
public Object[] timeseriesQueryRunners()
{
return QueryRunnerTestHelper.makeQueryRunners(
new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(),
new TimeseriesQueryEngine(),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
)
).toArray();
}
@Test
@Parameters(method = "timeseriesQueryRunners")
public void testEmptyTimeseries(QueryRunner<Result<TimeseriesResultValue>> runner)
public void testVariancesOnCastedString() throws Exception
{
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.ALL_GRAN)
.intervals(QueryRunnerTestHelper.EMPTY_INTERVAL)
.aggregators(
Arrays.asList(
QueryRunnerTestHelper.ROWS_COUNT,
QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
new VarianceAggregatorFactory("variance", "index", null, null)
)
)
.descending(true)
.context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build();
Map<String, Object> resultMap = new HashMap<>();
resultMap.put("rows", 0L);
resultMap.put("index", NullHandling.defaultDoubleValue());
resultMap.put("variance", NullHandling.defaultDoubleValue());
List<Result<TimeseriesResultValue>> expectedResults = ImmutableList.of(
new Result<>(
DateTimes.of("2020-04-02"),
new TimeseriesResultValue(
resultMap
testQuery(
"SELECT\n"
+ "STDDEV_POP(CAST(dim1 AS DOUBLE)),\n"
+ "STDDEV_SAMP(CAST(dim1 AS DOUBLE)),\n"
+ "STDDEV(CAST(dim1 AS DOUBLE)),\n"
+ "VARIANCE(CAST(dim1 AS DOUBLE))\n"
+ "FROM numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(DATA_SOURCE)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(
new ExpressionVirtualColumn("v0", "CAST(\"dim1\", 'DOUBLE')", ValueType.DOUBLE, ExprMacroTable.nil())
)
.granularity(Granularities.ALL)
.aggregators(
new VarianceAggregatorFactory("a0:agg", "v0", "population", "double"),
new VarianceAggregatorFactory("a1:agg", "v0", "sample", "double"),
new VarianceAggregatorFactory("a2:agg", "v0", "sample", "double"),
new VarianceAggregatorFactory("a3:agg", "v0", "sample", "double")
)
.postAggregators(
new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
)
.context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{3.61497656362466, 3.960008417499471, 3.960008417499471, 15.681666666666667}
: new Object[]{4.074582459862878, 4.990323970779185, 4.990323970779185, 24.903333333333332}
)
);
Iterable<Result<TimeseriesResultValue>> actualResults = runner.run(QueryPlus.wrap(query)).toList();
TestHelper.assertExpectedResults(expectedResults, actualResults);
}
private static void assertResultsEquals(List<Object[]> expectedResults, List<Object[]> results)
@Override
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
{
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {

View File

@ -19,6 +19,7 @@
package org.apache.druid.segment;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
@ -57,21 +58,40 @@ public class IndexBuilder
.withMetrics(new CountAggregatorFactory("count"))
.build();
private SegmentWriteOutMediumFactory segmentWriteOutMediumFactory = OffHeapMemorySegmentWriteOutMediumFactory.instance();
private IndexMerger indexMerger = TestHelper.getTestIndexMergerV9(segmentWriteOutMediumFactory);
private IndexMerger indexMerger;
private File tmpDir;
private IndexSpec indexSpec = new IndexSpec();
private int maxRows = DEFAULT_MAX_ROWS;
private final ObjectMapper jsonMapper;
private final IndexIO indexIO;
private final List<InputRow> rows = new ArrayList<>();
private IndexBuilder()
private IndexBuilder(ObjectMapper jsonMapper, ColumnConfig columnConfig)
{
// Callers must use "create".
this.jsonMapper = jsonMapper;
this.indexIO = new IndexIO(jsonMapper, columnConfig);
this.indexMerger = new IndexMergerV9(jsonMapper, indexIO, segmentWriteOutMediumFactory);
}
public static IndexBuilder create()
{
return new IndexBuilder();
return new IndexBuilder(TestHelper.JSON_MAPPER, TestHelper.NO_CACHE_COLUMN_CONFIG);
}
public static IndexBuilder create(ColumnConfig columnConfig)
{
return new IndexBuilder(TestHelper.JSON_MAPPER, columnConfig);
}
public static IndexBuilder create(ObjectMapper jsonMapper)
{
return new IndexBuilder(jsonMapper, TestHelper.NO_CACHE_COLUMN_CONFIG);
}
public static IndexBuilder create(ObjectMapper jsonMapper, ColumnConfig columnConfig)
{
return new IndexBuilder(jsonMapper, columnConfig);
}
public IndexBuilder schema(IncrementalIndexSchema schema)
@ -83,7 +103,7 @@ public class IndexBuilder
public IndexBuilder segmentWriteOutMediumFactory(SegmentWriteOutMediumFactory segmentWriteOutMediumFactory)
{
this.segmentWriteOutMediumFactory = segmentWriteOutMediumFactory;
this.indexMerger = TestHelper.getTestIndexMergerV9(segmentWriteOutMediumFactory);
this.indexMerger = new IndexMergerV9(jsonMapper, indexIO, segmentWriteOutMediumFactory);
return this;
}
@ -112,17 +132,11 @@ public class IndexBuilder
}
public QueryableIndex buildMMappedIndex()
{
ColumnConfig noCacheColumnConfig = () -> 0;
return buildMMappedIndex(noCacheColumnConfig);
}
public QueryableIndex buildMMappedIndex(ColumnConfig columnConfig)
{
Preconditions.checkNotNull(indexMerger, "indexMerger");
Preconditions.checkNotNull(tmpDir, "tmpDir");
try (final IncrementalIndex incrementalIndex = buildIncrementalIndex()) {
return TestHelper.getTestIndexIO(columnConfig).loadIndex(
return indexIO.loadIndex(
indexMerger.persist(
incrementalIndex,
new File(

View File

@ -58,6 +58,7 @@ import java.util.stream.IntStream;
public class TestHelper
{
public static final ObjectMapper JSON_MAPPER = makeJsonMapper();
public static final ColumnConfig NO_CACHE_COLUMN_CONFIG = () -> 0;
public static IndexMergerV9 getTestIndexMergerV9(SegmentWriteOutMediumFactory segmentWriteOutMediumFactory)
{
@ -66,8 +67,7 @@ public class TestHelper
public static IndexIO getTestIndexIO()
{
ColumnConfig noCacheColumnConfig = () -> 0;
return getTestIndexIO(noCacheColumnConfig);
return getTestIndexIO(NO_CACHE_COLUMN_CONFIG);
}
public static IndexIO getTestIndexIO(ColumnConfig columnConfig)

View File

@ -52,6 +52,7 @@ import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.RowAdapter;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnConfig;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
@ -171,15 +172,19 @@ public class JoinTestHelper
public static IndexBuilder createFactIndexBuilder(final File tmpDir) throws IOException
{
return createFactIndexBuilder(tmpDir, -1);
return createFactIndexBuilder(TestHelper.NO_CACHE_COLUMN_CONFIG, tmpDir, -1);
}
public static IndexBuilder createFactIndexBuilder(final File tmpDir, final int numRows) throws IOException
public static IndexBuilder createFactIndexBuilder(
final ColumnConfig columnConfig,
final File tmpDir,
final int numRows
) throws IOException
{
return withRowsFromResource(
"/wikipedia/data.json",
rows -> IndexBuilder
.create()
.create(columnConfig)
.tmpDir(tmpDir)
.schema(
new IncrementalIndexSchema.Builder()

View File

@ -67,7 +67,7 @@ public class Aggregations
.getArgList()
.stream()
.map(i -> Expressions.fromFieldAccess(rowSignature, project, i))
.map(rexNode -> toDruidExpressionForSimpleAggregator(plannerContext, rowSignature, rexNode))
.map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, rowSignature, rexNode))
.collect(Collectors.toList());
if (args.stream().noneMatch(Objects::isNull)) {
@ -77,7 +77,21 @@ public class Aggregations
}
}
private static DruidExpression toDruidExpressionForSimpleAggregator(
/**
* Translate a Calcite {@link RexNode} to a Druid expression for the aggregators that require numeric type inputs.
* The returned expression can keep an explicit cast from strings to numbers when the column consumed by
* the expression is the string type.
*
* Consider using {@link Expressions#toDruidExpression(PlannerContext, RowSignature, RexNode)} for projections
* or the aggregators that don't require numeric inputs.
*
* @param plannerContext SQL planner context
* @param rowSignature signature of the rows to be extracted from
* @param rexNode expression meant to be applied on top of the rows
*
* @return DruidExpression referring to fields in rowOrder, or null if not possible to translate
*/
public static DruidExpression toDruidExpressionForNumericAggregator(
final PlannerContext plannerContext,
final RowSignature rowSignature,
final RexNode rexNode

View File

@ -166,13 +166,17 @@ public class Expressions
}
/**
* Translate a Calcite {@code RexNode} to a Druid expressions.
* Translate a Calcite {@link RexNode} to a Druid expression for projections or the aggregators that don't
* require numeric inputs.
*
* Consider using {@link org.apache.druid.sql.calcite.aggregation.Aggregations#toDruidExpressionForNumericAggregator}
* for the aggregators that require numeric inputs.
*
* @param plannerContext SQL planner context
* @param rowSignature signature of the rows to be extracted from
* @param rexNode expression meant to be applied on top of the rows
*
* @return rexNode referring to fields in rowOrder, or null if not possible
* @return DruidExpression referring to fields in rowOrder, or null if not possible to translate
*/
@Nullable
public static DruidExpression toDruidExpression(

View File

@ -439,10 +439,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase
@Before
public void setUp() throws Exception
{
walker = CalciteTests.createMockWalker(
conglomerate,
temporaryFolder.newFolder()
);
walker = createQuerySegmentWalker();
}
@After
@ -452,6 +449,14 @@ public class BaseCalciteQueryTest extends CalciteTestBase
walker = null;
}
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
return CalciteTests.createMockWalker(
conglomerate,
temporaryFolder.newFolder()
);
}
public void assertQueryIsUnplannable(final String sql)
{
assertQueryIsUnplannable(PLANNER_CONFIG_DEFAULT, sql);
@ -718,13 +723,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase
}
Assert.assertEquals(StringUtils.format("result count: %s", sql), expectedResults.size(), results.size());
for (int i = 0; i < results.size(); i++) {
Assert.assertArrayEquals(
StringUtils.format("result #%d: %s", i + 1, sql),
expectedResults.get(i),
results.get(i)
);
}
assertResultsEquals(sql, expectedResults, results);
if (expectedQueries != null) {
final List<Query> recordedQueries = queryLogHook.getRecordedQueries();
@ -744,6 +743,17 @@ public class BaseCalciteQueryTest extends CalciteTestBase
}
}
public void assertResultsEquals(String sql, List<Object[]> expectedResults, List<Object[]> results)
{
for (int i = 0; i < results.size(); i++) {
Assert.assertArrayEquals(
StringUtils.format("result #%d: %s", i + 1, sql),
expectedResults.get(i),
results.get(i)
);
}
}
public Set<Resource> analyzeResources(
PlannerConfig plannerConfig,
String sql,