diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java index c2e94429ce2..2283386f207 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java @@ -19,28 +19,14 @@ package org.apache.druid.msq.test; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.inject.Injector; -import com.google.inject.Module; -import org.apache.calcite.rel.RelRoot; -import org.apache.druid.guice.DruidInjectorBuilder; -import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.sql.MSQTaskSqlEngine; -import org.apache.druid.query.groupby.TestGroupByBuffers; -import org.apache.druid.server.QueryLifecycleFactory; import org.apache.druid.sql.calcite.BaseCalciteQueryTest; import org.apache.druid.sql.calcite.CalciteJoinQueryTest; import org.apache.druid.sql.calcite.QueryTestBuilder; import org.apache.druid.sql.calcite.SqlTestFrameworkConfig; -import org.apache.druid.sql.calcite.TempDirProducer; import org.apache.druid.sql.calcite.planner.JoinAlgorithm; import org.apache.druid.sql.calcite.planner.PlannerContext; -import org.apache.druid.sql.calcite.run.EngineFeature; -import org.apache.druid.sql.calcite.run.QueryMaker; -import org.apache.druid.sql.calcite.run.SqlEngine; -import org.apache.druid.sql.calcite.util.SqlTestFramework.StandardComponentSupplier; import java.util.Map; @@ -52,7 +38,6 @@ public class CalciteSelectJoinQueryMSQTest /** * Run all tests with {@link JoinAlgorithm#BROADCAST}. */ - @SqlTestFrameworkConfig.ComponentSupplier(BroadcastJoinComponentSupplier.class) public static class BroadcastTest extends Base { @Override @@ -61,12 +46,17 @@ public class CalciteSelectJoinQueryMSQTest return super.testBuilder() .verifyNativeQueries(new VerifyMSQSupportedNativeQueriesPredicate()); } + + @Override + protected JoinAlgorithm joinAlgorithm() + { + return JoinAlgorithm.BROADCAST; + } } /** * Run all tests with {@link JoinAlgorithm#SORT_MERGE}. */ - @SqlTestFrameworkConfig.ComponentSupplier(SortMergeJoinComponentSupplier.class) public static class SortMergeTest extends Base { @Override @@ -83,101 +73,32 @@ public class CalciteSelectJoinQueryMSQTest return super.testBuilder() .verifyNativeQueries(xs -> false); } + + @Override + protected JoinAlgorithm joinAlgorithm() + { + return JoinAlgorithm.SORT_MERGE; + + } } + @SqlTestFrameworkConfig.ComponentSupplier(StandardMSQComponentSupplier.class) public abstract static class Base extends CalciteJoinQueryTest { + protected abstract JoinAlgorithm joinAlgorithm(); + @Override protected QueryTestBuilder testBuilder() { - Map defaultCtx = ImmutableMap.builder() - .putAll(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT) - .put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, joinAlgorithm1().toString()) - .build(); - - -// isSortBasedJoin() - return -// new QueryTestBuilder(new CalciteTestConfig(defaultCtx, true)) - new QueryTestBuilder(new CalciteTestConfig(true)) + Map defaultCtx = ImmutableMap.builder() + .putAll(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT) + .put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, joinAlgorithm().toString()) + .build(); + return new QueryTestBuilder(new CalciteTestConfig(defaultCtx, true)) .addCustomRunner( new ExtractResultsFactory( () -> (MSQTestOverlordServiceClient) ((MSQTaskSqlEngine) queryFramework().engine()).overlordClient())) .skipVectorize(true); } - - private JoinAlgorithm joinAlgorithm1() - { - return isSortBasedJoin() ? JoinAlgorithm.SORT_MERGE : JoinAlgorithm.BROADCAST; - - } - } - - protected static class SortMergeJoinComponentSupplier extends AbstractJoinComponentSupplier - { - public SortMergeJoinComponentSupplier(TempDirProducer tempFolderProducer) - { - super(tempFolderProducer, JoinAlgorithm.SORT_MERGE); - } - } - - protected static class BroadcastJoinComponentSupplier extends AbstractJoinComponentSupplier - { - public BroadcastJoinComponentSupplier(TempDirProducer tempFolderProducer) - { - super(tempFolderProducer, JoinAlgorithm.BROADCAST); - } - } - - protected abstract static class AbstractJoinComponentSupplier extends StandardComponentSupplier - { - private JoinAlgorithm joinAlgorithm; - - public AbstractJoinComponentSupplier(TempDirProducer tempFolderProducer, JoinAlgorithm joinAlgorithm) - { - super(tempFolderProducer); - this.joinAlgorithm = joinAlgorithm; - } - - @Override - public void configureGuice(DruidInjectorBuilder builder) - { - super.configureGuice(builder); - builder.addModules( - CalciteMSQTestsHelper.fetchModules(tempDirProducer::newTempFolder, TestGroupByBuffers.createDefault()).toArray(new Module[0]) - ); - } - - @Override - public SqlEngine createEngine( - QueryLifecycleFactory qlf, - ObjectMapper queryJsonMapper, - Injector injector - ) - { - final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); - final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( - queryJsonMapper, - injector, - new MSQTestTaskActionClient(queryJsonMapper, injector), - workerMemoryParameters, - ImmutableList.of() - ); - return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper) - { - @Override - public boolean featureAvailable(EngineFeature feature) - { - return super.featureAvailable(feature); - } - - @Override - public QueryMaker buildQueryMakerForSelect(RelRoot relRoot, PlannerContext plannerContext) - { - plannerContext.queryContextMap().put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, joinAlgorithm.toString()); - return super.buildQueryMakerForSelect(relRoot, plannerContext); - } - }; - } } }