Add benchmark suite for MSQ window functions (#17377)

* Add benchmark suite for MSQ window functions

* Fix inspection checks

* Address review comment: Rename method
This commit is contained in:
Akshat Jain 2024-10-30 11:32:28 +05:30 committed by GitHub
parent 63c91ad813
commit 21e7e5cddd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 335 additions and 9 deletions

View File

@ -47,6 +47,16 @@
<version>${jmh.version}</version> <version>${jmh.version}</version>
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>org.reflections</groupId>
<artifactId>reflections</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.easymock</groupId> <groupId>org.easymock</groupId>
<artifactId>easymock</artifactId> <artifactId>easymock</artifactId>
@ -217,6 +227,13 @@
<version>${project.parent.version}</version> <version>${project.parent.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.apache.druid.extensions</groupId>
<artifactId>druid-multi-stage-query</artifactId>
<version>${project.parent.version}</version>
<classifier>tests</classifier>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<properties> <properties>

View File

@ -0,0 +1,224 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark.query;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Injector;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.sql.MSQTaskSqlEngine;
import org.apache.druid.msq.test.ExtractResultsFactory;
import org.apache.druid.msq.test.MSQTestOverlordServiceClient;
import org.apache.druid.msq.test.StandardMSQComponentSupplier;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.server.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.QueryTestBuilder;
import org.apache.druid.sql.calcite.SqlTestFrameworkConfig;
import org.apache.druid.sql.calcite.TempDirProducer;
import org.apache.druid.sql.calcite.util.TestDataBuilder;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.lang.annotation.Annotation;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Benchmark that tests various SQL queries with window functions against MSQ engine.
*/
@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(value = 1)
@Warmup(iterations = 1)
@Measurement(iterations = 5)
@SqlTestFrameworkConfig.ComponentSupplier(MSQWindowFunctionsBenchmark.MSQComponentSupplier.class)
public class MSQWindowFunctionsBenchmark extends BaseCalciteQueryTest
{
static {
NullHandling.initializeForTests();
}
private static final Logger log = new Logger(MSQWindowFunctionsBenchmark.class);
private final Closer closer = Closer.create();
@Param({"20000000"})
private int rowsPerSegment;
@Param({"2", "5"})
private int maxNumTasks;
private List<Annotation> annotations;
@Setup(Level.Trial)
public void setup()
{
annotations = Arrays.asList(MSQWindowFunctionsBenchmark.class.getAnnotations());
// Populate the QueryableIndex for the benchmark datasource.
TestDataBuilder.makeQueryableIndexForBenchmarkDatasource(closer, rowsPerSegment);
}
@TearDown(Level.Trial)
public void tearDown() throws Exception
{
closer.close();
}
@Benchmark
public void windowWithoutGroupBy(Blackhole blackhole)
{
String sql = "SELECT ROW_NUMBER() "
+ "OVER (PARTITION BY dimUniform ORDER BY dimSequential) "
+ "FROM benchmark_ds";
querySql(sql, blackhole);
}
@Benchmark
public void windowWithoutSorting(Blackhole blackhole)
{
String sql = "SELECT dimZipf, dimSequential,"
+ "ROW_NUMBER() "
+ "OVER (PARTITION BY dimZipf) "
+ "from benchmark_ds\n"
+ "group by dimZipf, dimSequential";
querySql(sql, blackhole);
}
@Benchmark
public void windowWithSorting(Blackhole blackhole)
{
String sql = "SELECT dimZipf, dimSequential,"
+ "ROW_NUMBER() "
+ "OVER (PARTITION BY dimZipf ORDER BY dimSequential) "
+ "from benchmark_ds\n"
+ "group by dimZipf, dimSequential";
querySql(sql, blackhole);
}
@Benchmark
public void windowWithHighCardinalityPartitionBy(Blackhole blackhole)
{
String sql = "select\n"
+ "__time,\n"
+ "row_number() over (partition by __time) as c1\n"
+ "from benchmark_ds\n"
+ "group by __time";
querySql(sql, blackhole);
}
@Benchmark
public void windowWithLowCardinalityPartitionBy(Blackhole blackhole)
{
String sql = "select\n"
+ "dimZipf,\n"
+ "row_number() over (partition by dimZipf) as c1\n"
+ "from benchmark_ds\n"
+ "group by dimZipf";
querySql(sql, blackhole);
}
@Benchmark
public void multipleWindows(Blackhole blackhole)
{
String sql = "select\n"
+ "dimZipf, dimSequential, minFloatZipf,\n"
+ "row_number() over (partition by dimSequential order by minFloatZipf) as c1,\n"
+ "row_number() over (partition by dimZipf order by minFloatZipf) as c2,\n"
+ "row_number() over (partition by minFloatZipf order by minFloatZipf) as c3,\n"
+ "row_number() over (partition by dimSequential, dimZipf order by minFloatZipf, dimSequential) as c4,\n"
+ "row_number() over (partition by minFloatZipf, dimZipf order by dimSequential) as c5,\n"
+ "row_number() over (partition by minFloatZipf, dimSequential order by dimZipf) as c6,\n"
+ "row_number() over (partition by dimSequential, minFloatZipf, dimZipf order by dimZipf, minFloatZipf) as c7,\n"
+ "row_number() over (partition by dimSequential, minFloatZipf, dimZipf order by minFloatZipf) as c8\n"
+ "from benchmark_ds\n"
+ "group by dimZipf, dimSequential, minFloatZipf";
querySql(sql, blackhole);
}
public void querySql(String sql, Blackhole blackhole)
{
final Map<String, Object> context = ImmutableMap.of(
MultiStageQueryContext.CTX_MAX_NUM_TASKS, maxNumTasks
);
CalciteTestConfig calciteTestConfig = createCalciteTestConfig();
QueryTestBuilder queryTestBuilder = new QueryTestBuilder(calciteTestConfig)
.addCustomRunner(
new ExtractResultsFactory(() -> (MSQTestOverlordServiceClient) ((MSQTaskSqlEngine) queryFramework().engine()).overlordClient())
);
queryFrameworkRule.setConfig(new SqlTestFrameworkConfig(annotations));
final List<Object[]> resultList = queryTestBuilder
.skipVectorize(true)
.queryContext(context)
.sql(sql)
.results()
.results;
if (!resultList.isEmpty()) {
log.info("Total number of rows returned by query: %d", resultList.size());
Object[] lastRow = resultList.get(resultList.size() - 1);
blackhole.consume(lastRow);
} else {
log.info("No rows returned by the query.");
}
}
protected static class MSQComponentSupplier extends StandardMSQComponentSupplier
{
public MSQComponentSupplier(TempDirProducer tempFolderProducer)
{
super(tempFolderProducer);
}
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker(
QueryRunnerFactoryConglomerate conglomerate,
JoinableFactoryWrapper joinableFactory,
Injector injector
)
{
final SpecificSegmentsQuerySegmentWalker retVal = super.createQuerySegmentWalker(
conglomerate,
joinableFactory,
injector);
TestDataBuilder.attachIndexesForBenchmarkDatasource(retVal);
return retVal;
}
}
}

View File

@ -352,6 +352,17 @@
</excludes> </excludes>
</configuration> </configuration>
</plugin> </plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins> </plugins>
</build> </build>
</project> </project>

View File

@ -416,6 +416,9 @@ public class CalciteMSQTestsHelper
case CalciteTests.T_ALL_TYPE_PARQUET: case CalciteTests.T_ALL_TYPE_PARQUET:
index = TestDataBuilder.getQueryableIndexForDrillDatasource(segmentId.getDataSource(), tempFolderProducer.apply("tmpDir")); index = TestDataBuilder.getQueryableIndexForDrillDatasource(segmentId.getDataSource(), tempFolderProducer.apply("tmpDir"));
break; break;
case CalciteTests.BENCHMARK_DATASOURCE:
index = TestDataBuilder.getQueryableIndexForBenchmarkDatasource();
break;
default: default:
throw new ISE("Cannot query segment %s in test runner", segmentId); throw new ISE("Cannot query segment %s in test runner", segmentId);

View File

@ -23,7 +23,6 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.indexing.report.MSQResultsReport.ColumnAndType; import org.apache.druid.msq.indexing.report.MSQResultsReport.ColumnAndType;
import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReport;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.QueryTestBuilder; import org.apache.druid.sql.calcite.QueryTestBuilder;
import org.apache.druid.sql.calcite.QueryTestRunner; import org.apache.druid.sql.calcite.QueryTestRunner;
import org.junit.Assert; import org.junit.Assert;
@ -57,7 +56,6 @@ public class ExtractResultsFactory implements QueryTestRunner.QueryRunStepFactor
return new QueryTestRunner.BaseExecuteQuery(builder) return new QueryTestRunner.BaseExecuteQuery(builder)
{ {
final List<QueryTestRunner.QueryResults> extractedResults = new ArrayList<>(); final List<QueryTestRunner.QueryResults> extractedResults = new ArrayList<>();
final RowSignature resultsSignature = null;
final MSQTestOverlordServiceClient overlordClient = overlordClientSupplier.get(); final MSQTestOverlordServiceClient overlordClient = overlordClientSupplier.get();

View File

@ -31,7 +31,7 @@ import org.apache.druid.sql.calcite.TempDirProducer;
import org.apache.druid.sql.calcite.run.SqlEngine; import org.apache.druid.sql.calcite.run.SqlEngine;
import org.apache.druid.sql.calcite.util.SqlTestFramework.StandardComponentSupplier; import org.apache.druid.sql.calcite.util.SqlTestFramework.StandardComponentSupplier;
public final class StandardMSQComponentSupplier extends StandardComponentSupplier public class StandardMSQComponentSupplier extends StandardComponentSupplier
{ {
public StandardMSQComponentSupplier(TempDirProducer tempFolderProducer) public StandardMSQComponentSupplier(TempDirProducer tempFolderProducer)
{ {

View File

@ -645,7 +645,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase
} }
@RegisterExtension @RegisterExtension
static SqlTestFrameworkConfig.Rule queryFrameworkRule = new SqlTestFrameworkConfig.Rule(); protected static SqlTestFrameworkConfig.Rule queryFrameworkRule = new SqlTestFrameworkConfig.Rule();
public SqlTestFramework queryFramework() public SqlTestFramework queryFramework()
{ {
@ -896,6 +896,11 @@ public class BaseCalciteQueryTest extends CalciteTestBase
.skipVectorize(skipVectorize); .skipVectorize(skipVectorize);
} }
public CalciteTestConfig createCalciteTestConfig()
{
return new CalciteTestConfig();
}
public class CalciteTestConfig implements QueryTestBuilder.QueryTestConfig public class CalciteTestConfig implements QueryTestBuilder.QueryTestConfig
{ {
private boolean isRunningMSQ = false; private boolean isRunningMSQ = false;

View File

@ -772,8 +772,11 @@ public class QueryTestRunner
public QueryResults resultsOnly() public QueryResults resultsOnly()
{ {
ExecuteQuery execStep = (ExecuteQuery) runSteps.get(0); for (QueryRunStep runStep : runSteps) {
execStep.run(); runStep.run();
}
BaseExecuteQuery execStep = (BaseExecuteQuery) runSteps.get(runSteps.size() - 1);
return execStep.results().get(0); return execStep.results().get(0);
} }
} }

View File

@ -298,16 +298,21 @@ public class SqlTestFrameworkConfig
@Override @Override
public void beforeEach(ExtensionContext context) public void beforeEach(ExtensionContext context)
{ {
setConfig(context); makeConfigFromContext(context);
} }
private void setConfig(ExtensionContext context) public void makeConfigFromContext(ExtensionContext context)
{ {
testName = buildTestCaseName(context); testName = buildTestCaseName(context);
method = context.getTestMethod().get(); method = context.getTestMethod().get();
Class<?> testClass = context.getTestClass().get(); Class<?> testClass = context.getTestClass().get();
List<Annotation> annotations = collectAnnotations(testClass, method); List<Annotation> annotations = collectAnnotations(testClass, method);
config = new SqlTestFrameworkConfig(annotations); setConfig(new SqlTestFrameworkConfig(annotations));
}
public void setConfig(SqlTestFrameworkConfig config)
{
this.config = config;
} }
/** /**

View File

@ -130,6 +130,7 @@ public class CalciteTests
public static final String ALL_TYPES_UNIQ_PARQUET = "allTypsUniq.parquet"; public static final String ALL_TYPES_UNIQ_PARQUET = "allTypsUniq.parquet";
public static final String FEW_ROWS_ALL_DATA_PARQUET = "fewRowsAllData.parquet"; public static final String FEW_ROWS_ALL_DATA_PARQUET = "fewRowsAllData.parquet";
public static final String T_ALL_TYPE_PARQUET = "t_alltype.parquet"; public static final String T_ALL_TYPE_PARQUET = "t_alltype.parquet";
public static final String BENCHMARK_DATASOURCE = "benchmark_ds";
public static final String TEST_SUPERUSER_NAME = "testSuperuser"; public static final String TEST_SUPERUSER_NAME = "testSuperuser";
public static final AuthorizerMapper TEST_AUTHORIZER_MAPPER = new AuthorizerMapper(null) public static final AuthorizerMapper TEST_AUTHORIZER_MAPPER = new AuthorizerMapper(null)

View File

@ -44,6 +44,8 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.parsers.JSONPathSpec; import org.apache.druid.java.util.common.parsers.JSONPathSpec;
import org.apache.druid.query.DataSource; import org.apache.druid.query.DataSource;
import org.apache.druid.query.GlobalTableDataSource; import org.apache.druid.query.GlobalTableDataSource;
@ -64,6 +66,7 @@ import org.apache.druid.query.aggregation.firstlast.last.StringLastAggregatorFac
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory; import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.GroupByQueryConfig;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider;
import org.apache.druid.segment.AutoTypeColumnSchema;
import org.apache.druid.segment.IndexBuilder; import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.IndexSpec; import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.QueryableIndex;
@ -71,6 +74,10 @@ import org.apache.druid.segment.SegmentWrangler;
import org.apache.druid.segment.TestIndex; import org.apache.druid.segment.TestIndex;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.StringEncodingStrategy;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
import org.apache.druid.segment.generator.SegmentGenerator;
import org.apache.druid.segment.incremental.IncrementalIndex; import org.apache.druid.segment.incremental.IncrementalIndex;
import org.apache.druid.segment.incremental.IncrementalIndexSchema; import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinConditionAnalysis;
@ -79,6 +86,7 @@ import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.segment.join.JoinableFactoryWrapper;
import org.apache.druid.segment.join.table.IndexedTableJoinable; import org.apache.druid.segment.join.table.IndexedTableJoinable;
import org.apache.druid.segment.join.table.RowBasedIndexedTable; import org.apache.druid.segment.join.table.RowBasedIndexedTable;
import org.apache.druid.segment.transform.TransformSpec;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryScheduler;
import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.QueryStackTests;
@ -112,6 +120,8 @@ public class TestDataBuilder
public static final String TIMESTAMP_COLUMN = "t"; public static final String TIMESTAMP_COLUMN = "t";
public static final GlobalTableDataSource CUSTOM_TABLE = new GlobalTableDataSource(CalciteTests.BROADCAST_DATASOURCE); public static final GlobalTableDataSource CUSTOM_TABLE = new GlobalTableDataSource(CalciteTests.BROADCAST_DATASOURCE);
public static QueryableIndex QUERYABLE_INDEX_FOR_BENCHMARK_DATASOURCE = null;
public static final JoinableFactory CUSTOM_ROW_TABLE_JOINABLE = new JoinableFactory() public static final JoinableFactory CUSTOM_ROW_TABLE_JOINABLE = new JoinableFactory()
{ {
@Override @Override
@ -978,6 +988,21 @@ public class TestDataBuilder
attachIndexForDrillTestDatasource(segmentWalker, CalciteTests.T_ALL_TYPE_PARQUET, tmpDir); attachIndexForDrillTestDatasource(segmentWalker, CalciteTests.T_ALL_TYPE_PARQUET, tmpDir);
} }
public static void attachIndexesForBenchmarkDatasource(SpecificSegmentsQuerySegmentWalker segmentWalker)
{
final QueryableIndex queryableIndex = getQueryableIndexForBenchmarkDatasource();
segmentWalker.add(
DataSegment.builder()
.dataSource(CalciteTests.BENCHMARK_DATASOURCE)
.interval(Intervals.ETERNITY)
.version("1")
.shardSpec(new NumberedShardSpec(0, 0))
.size(0)
.build(),
queryableIndex);
}
@SuppressWarnings({"rawtypes", "unchecked"}) @SuppressWarnings({"rawtypes", "unchecked"})
private static void attachIndexForDrillTestDatasource( private static void attachIndexForDrillTestDatasource(
SpecificSegmentsQuerySegmentWalker segmentWalker, SpecificSegmentsQuerySegmentWalker segmentWalker,
@ -1014,6 +1039,40 @@ public class TestDataBuilder
.buildMMappedIndex(); .buildMMappedIndex();
} }
public static QueryableIndex getQueryableIndexForBenchmarkDatasource()
{
if (QUERYABLE_INDEX_FOR_BENCHMARK_DATASOURCE == null) {
throw new RuntimeException("Queryable index was not populated for benchmark datasource.");
}
return QUERYABLE_INDEX_FOR_BENCHMARK_DATASOURCE;
}
public static void makeQueryableIndexForBenchmarkDatasource(Closer closer, int rowsPerSegment)
{
if (closer == null) {
throw new RuntimeException("Closer not supplied for generating segments, exiting.");
}
final GeneratorSchemaInfo schemaInfo = GeneratorBasicSchemas.SCHEMA_MAP.get("basic");
final DataSegment dataSegment = schemaInfo.makeSegmentDescriptor(CalciteTests.BENCHMARK_DATASOURCE);
final SegmentGenerator segmentGenerator = closer.register(new SegmentGenerator());
List<DimensionSchema> columnSchemas = schemaInfo.getDimensionsSpec()
.getDimensions()
.stream()
.map(x -> new AutoTypeColumnSchema(x.getName(), null))
.collect(Collectors.toList());
QUERYABLE_INDEX_FOR_BENCHMARK_DATASOURCE = segmentGenerator.generate(
dataSegment,
schemaInfo,
DimensionsSpec.builder().setDimensions(columnSchemas).build(),
TransformSpec.NONE,
IndexSpec.builder().withStringDictionaryEncoding(new StringEncodingStrategy.Utf8()).build(),
Granularities.NONE,
rowsPerSegment
);
}
private static DimensionsSpec getDimensionSpecForDrillDatasource(String datasource) private static DimensionsSpec getDimensionSpecForDrillDatasource(String datasource)
{ {
switch (datasource) { switch (datasource) {