diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java
new file mode 100644
index 00000000000..d41589142d7
--- /dev/null
+++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.benchmark.query;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.data.input.impl.DimensionSchema;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.java.util.common.io.Closer;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.math.expr.ExpressionProcessing;
+import org.apache.druid.query.DruidProcessingConfig;
+import org.apache.druid.query.QueryContexts;
+import org.apache.druid.query.QueryRunnerFactoryConglomerate;
+import org.apache.druid.query.expression.TestExprMacroTable;
+import org.apache.druid.segment.NestedDataDimensionSchema;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.generator.GeneratorBasicSchemas;
+import org.apache.druid.segment.generator.GeneratorSchemaInfo;
+import org.apache.druid.segment.generator.SegmentGenerator;
+import org.apache.druid.segment.transform.ExpressionTransform;
+import org.apache.druid.segment.transform.TransformSpec;
+import org.apache.druid.server.QueryStackTests;
+import org.apache.druid.server.security.AuthTestUtils;
+import org.apache.druid.sql.calcite.SqlVectorizedExpressionSanityTest;
+import org.apache.druid.sql.calcite.planner.CalciteRulesManager;
+import org.apache.druid.sql.calcite.planner.Calcites;
+import org.apache.druid.sql.calcite.planner.DruidPlanner;
+import org.apache.druid.sql.calcite.planner.PlannerConfig;
+import org.apache.druid.sql.calcite.planner.PlannerFactory;
+import org.apache.druid.sql.calcite.planner.PlannerResult;
+import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
+import org.apache.druid.timeline.DataSegment;
+import org.apache.druid.timeline.partition.LinearShardSpec;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import javax.annotation.Nullable;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+@State(Scope.Benchmark)
+@Fork(value = 1)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+public class SqlNestedDataBenchmark
+{
+  private static final Logger log = new Logger(SqlNestedDataBenchmark.class);
+
+  static {
+    NullHandling.initializeForTests();
+    Calcites.setSystemProperties();
+    ExpressionProcessing.initializeForStrictBooleansTests(true);
+  }
+
+  private static final DruidProcessingConfig PROCESSING_CONFIG = new DruidProcessingConfig()
+  {
+    @Override
+    public int intermediateComputeSizeBytes()
+    {
+      return 512 * 1024 * 1024;
+    }
+
+    @Override
+    public int getNumMergeBuffers()
+    {
+      return 3;
+    }
+
+    @Override
+    public int getNumThreads()
+    {
+      return 1;
+    }
+
+    @Override
+    public boolean useParallelMergePoolConfigured()
+    {
+      return true;
+    }
+
+    @Override
+    public String getFormatString()
+    {
+      return "benchmarks-processing-%s";
+    }
+  };
+
+
+  private static final List<String> QUERIES = ImmutableList.of(
+      // ===========================
+      // non-nested reference queries
+      // ===========================
+      // 0,1: timeseries, 1 columns
+      "SELECT SUM(long1) FROM foo",
+      "SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo",
+      // 2,3: timeseries, 2 columns
+      "SELECT SUM(long1), SUM(long2) FROM foo",
+      "SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)), SUM(JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT)) FROM foo",
+      // 4,5: timeseries, 3 columns
+      "SELECT SUM(long1), SUM(long2), SUM(double3) FROM foo",
+      "SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)), SUM(JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT)), SUM(JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE)) FROM foo",
+      // 6,7: group by string with 1 agg
+      "SELECT string1, SUM(long1) FROM foo GROUP BY 1 ORDER BY 2",
+      "SELECT JSON_VALUE(nested, '$.nesteder.string1'), SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo GROUP BY 1 ORDER BY 2",
+      // 8,9: group by string with 2 agg
+      "SELECT string1, SUM(long1), SUM(double3) FROM foo GROUP BY 1 ORDER BY 2",
+      "SELECT JSON_VALUE(nested, '$.nesteder.string1'), SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)), SUM(JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE)) FROM foo GROUP BY 1 ORDER BY 2",
+      // 10,11: time-series filter string
+      "SELECT SUM(long1) FROM foo WHERE string1 = '10000' OR string1 = '1000'",
+      "SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.string1') = '10000' OR JSON_VALUE(nested, '$.nesteder.string1') = '1000'",
+      // 12,13: time-series filter long
+      "SELECT SUM(long1) FROM foo WHERE long2 = 10000 OR long2 = 1000",
+      "SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) = 10000 OR JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) = 1000",
+      // 14,15: time-series filter double
+      "SELECT SUM(long1) FROM foo WHERE double3 < 10000.0 AND double3 > 1000.0",
+      "SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE) < 10000.0 AND JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE) > 1000.0",
+      // 16,17: group by long filter by string
+      "SELECT long1, SUM(double3) FROM foo WHERE string1 = '10000' OR string1 = '1000' GROUP BY 1 ORDER BY 2",
+      "SELECT JSON_VALUE(nested, '$.long1' RETURNING BIGINT), SUM(JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.string1') = '10000' OR JSON_VALUE(nested, '$.nesteder.string1') = '1000' GROUP BY 1 ORDER BY 2",
+      // 18,19: group by string filter by long
+      "SELECT string1, SUM(double3) FROM foo WHERE long2 < 10000 AND long2 > 1000 GROUP BY 1 ORDER BY 2",
+      "SELECT JSON_VALUE(nested, '$.nesteder.string1'), SUM(JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) < 10000 AND JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) > 1000 GROUP BY 1 ORDER BY 2",
+      // 20,21: group by string filter by double
+      "SELECT string1, SUM(double3) FROM foo WHERE double3 < 10000.0 AND double3 > 1000.0 GROUP BY 1 ORDER BY 2",
+      "SELECT JSON_VALUE(nested, '$.nesteder.string1'), SUM(JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE) < 10000.0 AND JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE) > 1000.0 GROUP BY 1 ORDER BY 2",
+      // 22, 23:
+      "SELECT long2 FROM foo WHERE long2 IN (1, 19, 21, 23, 25, 26, 46)",
+      "SELECT JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) IN (1, 19, 21, 23, 25, 26, 46)",
+      // 24, 25
+      "SELECT long2 FROM foo WHERE long2 IN (1, 19, 21, 23, 25, 26, 46) GROUP BY 1",
+      "SELECT JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) IN (1, 19, 21, 23, 25, 26, 46) GROUP BY 1"
+  );
+
+  @Param({"5000000"})
+  private int rowsPerSegment;
+
+  @Param({
+      "false",
+      "force"
+  })
+  private String vectorize;
+
+  @Param({
+      "0",
+      "1",
+      "2",
+      "3",
+      "4",
+      "5",
+      "6",
+      "7",
+      "8",
+      "9",
+      "10",
+      "11",
+      "12",
+      "13",
+      "14",
+      "15",
+      "16",
+      "17",
+      "18",
+      "19",
+      "20",
+      "21",
+      "22",
+      "23",
+      "24",
+      "25"
+  })
+  private String query;
+
+  @Nullable
+  private PlannerFactory plannerFactory;
+  private Closer closer = Closer.create();
+
+  @Setup(Level.Trial)
+  public void setup()
+  {
+    final GeneratorSchemaInfo schemaInfo = GeneratorBasicSchemas.SCHEMA_MAP.get("expression-testbench");
+
+    final DataSegment dataSegment = DataSegment.builder()
+                                               .dataSource("foo")
+                                               .interval(schemaInfo.getDataInterval())
+                                               .version("1")
+                                               .shardSpec(new LinearShardSpec(0))
+                                               .size(0)
+                                               .build();
+
+
+
+    final PlannerConfig plannerConfig = new PlannerConfig();
+
+    final SegmentGenerator segmentGenerator = closer.register(new SegmentGenerator());
+    log.info("Starting benchmark setup using cacheDir[%s], rows[%,d].", segmentGenerator.getCacheDir(), rowsPerSegment);
+
+    TransformSpec transformSpec = new TransformSpec(
+        null,
+        ImmutableList.of(
+            new ExpressionTransform(
+                "nested",
+                "json_object('long1', long1, 'nesteder', json_object('string1', string1, 'long2', long2, 'double3',double3))",
+                TestExprMacroTable.INSTANCE
+            )
+        )
+    );
+    List<DimensionSchema> dims = ImmutableList.<DimensionSchema>builder()
+                                              .addAll(schemaInfo.getDimensionsSpec().getDimensions())
+                                              .add(new NestedDataDimensionSchema("nested"))
+                                              .build();
+    DimensionsSpec dimsSpec = new DimensionsSpec(dims);
+    final QueryableIndex index = segmentGenerator.generate(
+        dataSegment,
+        schemaInfo,
+        dimsSpec,
+        transformSpec,
+        Granularities.NONE,
+        rowsPerSegment
+    );
+
+    final QueryRunnerFactoryConglomerate conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(
+        closer,
+        PROCESSING_CONFIG
+    );
+
+    final SpecificSegmentsQuerySegmentWalker walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
+        dataSegment,
+        index
+    );
+    closer.register(walker);
+
+    final DruidSchemaCatalog rootSchema =
+        CalciteTests.createMockRootSchema(conglomerate, walker, plannerConfig, AuthTestUtils.TEST_AUTHORIZER_MAPPER);
+    plannerFactory = new PlannerFactory(
+        rootSchema,
+        CalciteTests.createMockQueryMakerFactory(walker, conglomerate),
+        CalciteTests.createOperatorTable(),
+        CalciteTests.createExprMacroTable(),
+        plannerConfig,
+        AuthTestUtils.TEST_AUTHORIZER_MAPPER,
+        CalciteTests.getJsonMapper(),
+        CalciteTests.DRUID_SCHEMA_NAME,
+        new CalciteRulesManager(ImmutableSet.of())
+    );
+
+    try {
+      SqlVectorizedExpressionSanityTest.sanityTestVectorizedSqlQueries(
+          plannerFactory,
+          QUERIES.get(Integer.parseInt(query))
+      );
+    }
+    catch (Throwable ignored) {
+      // the show must go on
+    }
+  }
+
+  @TearDown(Level.Trial)
+  public void tearDown() throws Exception
+  {
+    closer.close();
+  }
+
+  @Benchmark
+  @BenchmarkMode(Mode.AverageTime)
+  @OutputTimeUnit(TimeUnit.MILLISECONDS)
+  public void querySql(Blackhole blackhole) throws Exception
+  {
+    final Map<String, Object> context = ImmutableMap.of(
+        QueryContexts.VECTORIZE_KEY, vectorize,
+        QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize
+    );
+    final String sql = QUERIES.get(Integer.parseInt(query));
+    try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(context, sql)) {
+      final PlannerResult plannerResult = planner.plan();
+      final Sequence<Object[]> resultSequence = plannerResult.run();
+      final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in);
+      blackhole.consume(lastRow);
+    }
+  }
+}
diff --git a/processing/src/test/java/org/apache/druid/segment/generator/SegmentGenerator.java b/processing/src/test/java/org/apache/druid/segment/generator/SegmentGenerator.java
index 4f21658aaa9..1856325dae5 100644
--- a/processing/src/test/java/org/apache/druid/segment/generator/SegmentGenerator.java
+++ b/processing/src/test/java/org/apache/druid/segment/generator/SegmentGenerator.java
@@ -22,6 +22,9 @@ package org.apache.druid.segment.generator;
 import com.google.common.hash.Hashing;
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.data.input.InputRow;
+import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.guice.NestedDataModule;
 import org.apache.druid.java.util.common.FileUtils;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.StringUtils;
@@ -38,6 +41,8 @@ import org.apache.druid.segment.data.RoaringBitmapSerdeFactory;
 import org.apache.druid.segment.incremental.IncrementalIndex;
 import org.apache.druid.segment.incremental.IncrementalIndexSchema;
 import org.apache.druid.segment.serde.ComplexMetrics;
+import org.apache.druid.segment.transform.TransformSpec;
+import org.apache.druid.segment.transform.Transformer;
 import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
 import org.apache.druid.timeline.DataSegment;
 import org.apache.druid.timeline.SegmentId;
@@ -48,7 +53,9 @@ import java.io.File;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 public class SegmentGenerator implements Closeable
 {
@@ -98,6 +105,7 @@ public class SegmentGenerator implements Closeable
     return cacheDir;
   }
 
+
   public QueryableIndex generate(
       final DataSegment dataSegment,
       final GeneratorSchemaInfo schemaInfo,
@@ -105,14 +113,28 @@ public class SegmentGenerator implements Closeable
       final int numRows
   )
   {
-    // In case we need to generate hyperUniques.
+    return generate(dataSegment, schemaInfo, schemaInfo.getDimensionsSpec(), TransformSpec.NONE, granularity, numRows);
+  }
+
+  public QueryableIndex generate(
+      final DataSegment dataSegment,
+      final GeneratorSchemaInfo schemaInfo,
+      final DimensionsSpec dimensionsSpec,
+      final TransformSpec transformSpec,
+      final Granularity queryGranularity,
+      final int numRows
+  )
+  {
+    // In case we need to generate hyperUniques or json
     ComplexMetrics.registerSerde("hyperUnique", new HyperUniquesSerde());
+    NestedDataModule.registerHandlersAndSerde();
 
     final String dataHash = Hashing.sha256()
                                    .newHasher()
                                    .putString(dataSegment.getId().toString(), StandardCharsets.UTF_8)
                                    .putString(schemaInfo.toString(), StandardCharsets.UTF_8)
-                                   .putString(granularity.toString(), StandardCharsets.UTF_8)
+                                   .putString(dimensionsSpec.toString(), StandardCharsets.UTF_8)
+                                   .putString(queryGranularity.toString(), StandardCharsets.UTF_8)
                                    .putInt(numRows)
                                    .hash()
                                    .toString();
@@ -139,18 +161,25 @@ public class SegmentGenerator implements Closeable
     );
 
     final IncrementalIndexSchema indexSchema = new IncrementalIndexSchema.Builder()
-        .withDimensionsSpec(schemaInfo.getDimensionsSpec())
+        .withDimensionsSpec(dimensionsSpec)
        .withMetrics(schemaInfo.getAggsArray())
        .withRollup(schemaInfo.isWithRollup())
-        .withQueryGranularity(granularity)
+        .withQueryGranularity(queryGranularity)
        .build();
 
     final List<InputRow> rows = new ArrayList<>();
     final List<QueryableIndex> indexes = new ArrayList<>();
 
+    Transformer transformer = transformSpec.toTransformer();
+
     for (int i = 0; i < numRows; i++) {
-      final InputRow row = dataGenerator.nextRow();
-      rows.add(row);
+      final InputRow row = transformer.transform(dataGenerator.nextRow());
+      Map<String, Object> evaluated = new HashMap<>();
+      for (String dimension : dimensionsSpec.getDimensionNames()) {
+        evaluated.put(dimension, row.getRaw(dimension));
+      }
+      MapBasedInputRow transformedRow = new MapBasedInputRow(row.getTimestamp(), dimensionsSpec.getDimensionNames(), evaluated);
+      rows.add(transformedRow);
 
       if ((i + 1) % 20000 == 0) {
         log.info("%,d/%,d rows generated for[%s].", i + 1, numRows, dataSegment);