Limit max concurrency of test cluster nodes to a function of max workers (#51338)
(cherry picked from commit 9a0238c7166e70e467ca61c1353157979fd1598b)
This commit is contained in:
parent
40f4f2d267
commit
4f214d20ab
|
@ -1,12 +1,22 @@
|
|||
package org.elasticsearch.gradle.testclusters;
|
||||
|
||||
import org.elasticsearch.gradle.tool.Boilerplate;
|
||||
import org.gradle.api.provider.Provider;
|
||||
import org.gradle.api.services.internal.BuildServiceRegistryInternal;
|
||||
import org.gradle.api.tasks.CacheableTask;
|
||||
import org.gradle.api.tasks.Internal;
|
||||
import org.gradle.api.tasks.Nested;
|
||||
import org.gradle.api.tasks.testing.Test;
|
||||
import org.gradle.internal.resources.ResourceLock;
|
||||
import org.gradle.internal.resources.SharedResource;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.gradle.testclusters.TestClustersPlugin.THROTTLE_SERVICE_NAME;
|
||||
|
||||
/**
|
||||
* Customized version of Gradle {@link Test} task which tracks a collection of {@link ElasticsearchCluster} as a task input. We must do this
|
||||
|
@ -47,4 +57,19 @@ public class RestTestRunnerTask extends Test implements TestClustersAware {
|
|||
return clusters;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Internal
|
||||
public List<ResourceLock> getSharedResources() {
|
||||
List<ResourceLock> locks = new ArrayList<>(super.getSharedResources());
|
||||
BuildServiceRegistryInternal serviceRegistry = getServices().get(BuildServiceRegistryInternal.class);
|
||||
Provider<TestClustersThrottle> throttleProvider = Boilerplate.getBuildService(serviceRegistry, THROTTLE_SERVICE_NAME);
|
||||
SharedResource resource = serviceRegistry.forService(throttleProvider);
|
||||
|
||||
int nodeCount = clusters.stream().mapToInt(cluster -> cluster.getNodes().size()).sum();
|
||||
if (nodeCount > 0) {
|
||||
locks.add(resource.getResourceLock(Math.min(nodeCount, resource.getMaxUsages())));
|
||||
}
|
||||
|
||||
return Collections.unmodifiableList(locks);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.elasticsearch.gradle.testclusters;
|
|||
import org.elasticsearch.gradle.DistributionDownloadPlugin;
|
||||
import org.elasticsearch.gradle.ReaperPlugin;
|
||||
import org.elasticsearch.gradle.ReaperService;
|
||||
import org.elasticsearch.gradle.tool.Boilerplate;
|
||||
import org.gradle.api.NamedDomainObjectContainer;
|
||||
import org.gradle.api.Plugin;
|
||||
import org.gradle.api.Project;
|
||||
|
@ -30,53 +31,50 @@ import org.gradle.api.execution.TaskExecutionListener;
|
|||
import org.gradle.api.invocation.Gradle;
|
||||
import org.gradle.api.logging.Logger;
|
||||
import org.gradle.api.logging.Logging;
|
||||
import org.gradle.api.provider.Provider;
|
||||
import org.gradle.api.tasks.TaskState;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class TestClustersPlugin implements Plugin<Project> {
|
||||
|
||||
private static final String LIST_TASK_NAME = "listTestClusters";
|
||||
public static final String EXTENSION_NAME = "testClusters";
|
||||
private static final String REGISTRY_EXTENSION_NAME = "testClustersRegistry";
|
||||
public static final String THROTTLE_SERVICE_NAME = "testClustersThrottle";
|
||||
|
||||
private static final String LIST_TASK_NAME = "listTestClusters";
|
||||
private static final String REGISTRY_SERVICE_NAME = "testClustersRegistry";
|
||||
private static final Logger logger = Logging.getLogger(TestClustersPlugin.class);
|
||||
|
||||
private ReaperService reaper;
|
||||
|
||||
@Override
|
||||
public void apply(Project project) {
|
||||
project.getPlugins().apply(DistributionDownloadPlugin.class);
|
||||
|
||||
project.getRootProject().getPluginManager().apply(ReaperPlugin.class);
|
||||
reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);
|
||||
|
||||
ReaperService reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);
|
||||
|
||||
// enable the DSL to describe clusters
|
||||
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project);
|
||||
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project, reaper);
|
||||
|
||||
// provide a task to be able to list defined clusters.
|
||||
createListClustersTask(project, container);
|
||||
|
||||
if (project.getRootProject().getExtensions().findByName(REGISTRY_EXTENSION_NAME) == null) {
|
||||
TestClustersRegistry registry = project.getRootProject()
|
||||
.getExtensions()
|
||||
.create(REGISTRY_EXTENSION_NAME, TestClustersRegistry.class);
|
||||
// register cluster registry as a global build service
|
||||
project.getGradle().getSharedServices().registerIfAbsent(REGISTRY_SERVICE_NAME, TestClustersRegistry.class, spec -> {});
|
||||
|
||||
// When we know what tasks will run, we claim the clusters of those task to differentiate between clusters
|
||||
// that are defined in the build script and the ones that will actually be used in this invocation of gradle
|
||||
// we use this information to determine when the last task that required the cluster executed so that we can
|
||||
// terminate the cluster right away and free up resources.
|
||||
configureClaimClustersHook(project.getGradle(), registry);
|
||||
// register throttle so we only run at most max-workers/2 nodes concurrently
|
||||
project.getGradle()
|
||||
.getSharedServices()
|
||||
.registerIfAbsent(
|
||||
THROTTLE_SERVICE_NAME,
|
||||
TestClustersThrottle.class,
|
||||
spec -> spec.getMaxParallelUsages().set(project.getGradle().getStartParameter().getMaxWorkerCount() / 2)
|
||||
);
|
||||
|
||||
// Before each task, we determine if a cluster needs to be started for that task.
|
||||
configureStartClustersHook(project.getGradle(), registry);
|
||||
|
||||
// After each task we determine if there are clusters that are no longer needed.
|
||||
configureStopClustersHook(project.getGradle(), registry);
|
||||
}
|
||||
// register cluster hooks
|
||||
project.getRootProject().getPluginManager().apply(TestClustersHookPlugin.class);
|
||||
}
|
||||
|
||||
private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project) {
|
||||
private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project, ReaperService reaper) {
|
||||
// Create an extensions that allows describing clusters
|
||||
NamedDomainObjectContainer<ElasticsearchCluster> container = project.container(
|
||||
ElasticsearchCluster.class,
|
||||
|
@ -95,52 +93,78 @@ public class TestClustersPlugin implements Plugin<Project> {
|
|||
);
|
||||
}
|
||||
|
||||
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
|
||||
// Once we know all the tasks that need to execute, we claim all the clusters that belong to those and count the
|
||||
// claims so we'll know when it's safe to stop them.
|
||||
gradle.getTaskGraph().whenReady(taskExecutionGraph -> {
|
||||
taskExecutionGraph.getAllTasks()
|
||||
.stream()
|
||||
.filter(task -> task instanceof TestClustersAware)
|
||||
.map(task -> (TestClustersAware) task)
|
||||
.flatMap(task -> task.getClusters().stream())
|
||||
.forEach(registry::claimCluster);
|
||||
});
|
||||
}
|
||||
|
||||
private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
|
||||
gradle.addListener(new TaskActionListener() {
|
||||
@Override
|
||||
public void beforeActions(Task task) {
|
||||
if (task instanceof TestClustersAware == false) {
|
||||
return;
|
||||
}
|
||||
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
|
||||
TestClustersAware awareTask = (TestClustersAware) task;
|
||||
awareTask.beforeStart();
|
||||
awareTask.getClusters().forEach(registry::maybeStartCluster);
|
||||
static class TestClustersHookPlugin implements Plugin<Project> {
|
||||
@Override
|
||||
public void apply(Project project) {
|
||||
if (project != project.getRootProject()) {
|
||||
throw new IllegalStateException(this.getClass().getName() + " can only be applied to the root project.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterActions(Task task) {}
|
||||
});
|
||||
}
|
||||
Provider<TestClustersRegistry> registryProvider = Boilerplate.getBuildService(
|
||||
project.getGradle().getSharedServices(),
|
||||
REGISTRY_SERVICE_NAME
|
||||
);
|
||||
TestClustersRegistry registry = registryProvider.get();
|
||||
|
||||
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
|
||||
gradle.addListener(new TaskExecutionListener() {
|
||||
@Override
|
||||
public void afterExecute(Task task, TaskState state) {
|
||||
if (task instanceof TestClustersAware == false) {
|
||||
return;
|
||||
// When we know what tasks will run, we claim the clusters of those task to differentiate between clusters
|
||||
// that are defined in the build script and the ones that will actually be used in this invocation of gradle
|
||||
// we use this information to determine when the last task that required the cluster executed so that we can
|
||||
// terminate the cluster right away and free up resources.
|
||||
configureClaimClustersHook(project.getGradle(), registry);
|
||||
|
||||
// Before each task, we determine if a cluster needs to be started for that task.
|
||||
configureStartClustersHook(project.getGradle(), registry);
|
||||
|
||||
// After each task we determine if there are clusters that are no longer needed.
|
||||
configureStopClustersHook(project.getGradle(), registry);
|
||||
}
|
||||
|
||||
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
|
||||
// Once we know all the tasks that need to execute, we claim all the clusters that belong to those and count the
|
||||
// claims so we'll know when it's safe to stop them.
|
||||
gradle.getTaskGraph().whenReady(taskExecutionGraph -> {
|
||||
taskExecutionGraph.getAllTasks()
|
||||
.stream()
|
||||
.filter(task -> task instanceof TestClustersAware)
|
||||
.map(task -> (TestClustersAware) task)
|
||||
.flatMap(task -> task.getClusters().stream())
|
||||
.forEach(registry::claimCluster);
|
||||
});
|
||||
}
|
||||
|
||||
private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
|
||||
gradle.addListener(new TaskActionListener() {
|
||||
@Override
|
||||
public void beforeActions(Task task) {
|
||||
if (task instanceof TestClustersAware == false) {
|
||||
return;
|
||||
}
|
||||
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
|
||||
TestClustersAware awareTask = (TestClustersAware) task;
|
||||
awareTask.beforeStart();
|
||||
awareTask.getClusters().forEach(registry::maybeStartCluster);
|
||||
}
|
||||
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
|
||||
// and caused the cluster to start.
|
||||
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeExecute(Task task) {}
|
||||
});
|
||||
@Override
|
||||
public void afterActions(Task task) {}
|
||||
});
|
||||
}
|
||||
|
||||
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
|
||||
gradle.addListener(new TaskExecutionListener() {
|
||||
@Override
|
||||
public void afterExecute(Task task, TaskState state) {
|
||||
if (task instanceof TestClustersAware == false) {
|
||||
return;
|
||||
}
|
||||
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
|
||||
// and caused the cluster to start.
|
||||
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeExecute(Task task) {}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2,13 +2,15 @@ package org.elasticsearch.gradle.testclusters;
|
|||
|
||||
import org.gradle.api.logging.Logger;
|
||||
import org.gradle.api.logging.Logging;
|
||||
import org.gradle.api.services.BuildService;
|
||||
import org.gradle.api.services.BuildServiceParameters;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class TestClustersRegistry {
|
||||
public abstract class TestClustersRegistry implements BuildService<BuildServiceParameters.None> {
|
||||
private static final Logger logger = Logging.getLogger(TestClustersRegistry.class);
|
||||
private static final String TESTCLUSTERS_INSPECT_FAILURE = "testclusters.inspect.failure";
|
||||
private final Boolean allowClusterToSurvive = Boolean.valueOf(System.getProperty(TESTCLUSTERS_INSPECT_FAILURE, "false"));
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
package org.elasticsearch.gradle.testclusters;
|
||||
|
||||
import org.gradle.api.services.BuildService;
|
||||
import org.gradle.api.services.BuildServiceParameters;
|
||||
|
||||
public abstract class TestClustersThrottle implements BuildService<BuildServiceParameters.None> {}
|
|
@ -19,12 +19,17 @@
|
|||
package org.elasticsearch.gradle.tool;
|
||||
|
||||
import org.gradle.api.Action;
|
||||
import org.gradle.api.GradleException;
|
||||
import org.gradle.api.NamedDomainObjectContainer;
|
||||
import org.gradle.api.PolymorphicDomainObjectContainer;
|
||||
import org.gradle.api.Project;
|
||||
import org.gradle.api.Task;
|
||||
import org.gradle.api.UnknownTaskException;
|
||||
import org.gradle.api.plugins.JavaPluginConvention;
|
||||
import org.gradle.api.provider.Provider;
|
||||
import org.gradle.api.services.BuildService;
|
||||
import org.gradle.api.services.BuildServiceRegistration;
|
||||
import org.gradle.api.services.BuildServiceRegistry;
|
||||
import org.gradle.api.tasks.SourceSetContainer;
|
||||
import org.gradle.api.tasks.TaskContainer;
|
||||
import org.gradle.api.tasks.TaskProvider;
|
||||
|
@ -102,4 +107,14 @@ public abstract class Boilerplate {
|
|||
|
||||
return task;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T extends BuildService<?>> Provider<T> getBuildService(BuildServiceRegistry registry, String name) {
|
||||
BuildServiceRegistration<?, ?> registration = registry.getRegistrations().findByName(name);
|
||||
if (registration == null) {
|
||||
throw new GradleException("Unable to find build service with name '" + name + "'.");
|
||||
}
|
||||
|
||||
return (Provider<T>) registration.getService();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue