Limit max concurrency of test cluster nodes to a function of max workers (#51338)

(cherry picked from commit 9a0238c7166e70e467ca61c1353157979fd1598b)
This commit is contained in:
Mark Vieira 2020-01-23 15:20:17 -08:00
parent 40f4f2d267
commit 4f214d20ab
No known key found for this signature in database
GPG Key ID: CA947EF7E6D4B105
5 changed files with 138 additions and 66 deletions

View File

@ -1,12 +1,22 @@
package org.elasticsearch.gradle.testclusters;
import org.elasticsearch.gradle.tool.Boilerplate;
import org.gradle.api.provider.Provider;
import org.gradle.api.services.internal.BuildServiceRegistryInternal;
import org.gradle.api.tasks.CacheableTask;
import org.gradle.api.tasks.Internal;
import org.gradle.api.tasks.Nested;
import org.gradle.api.tasks.testing.Test;
import org.gradle.internal.resources.ResourceLock;
import org.gradle.internal.resources.SharedResource;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import static org.elasticsearch.gradle.testclusters.TestClustersPlugin.THROTTLE_SERVICE_NAME;
/**
* Customized version of Gradle {@link Test} task which tracks a collection of {@link ElasticsearchCluster} as a task input. We must do this
@ -47,4 +57,19 @@ public class RestTestRunnerTask extends Test implements TestClustersAware {
return clusters;
}
@Override
@Internal
public List<ResourceLock> getSharedResources() {
List<ResourceLock> locks = new ArrayList<>(super.getSharedResources());
BuildServiceRegistryInternal serviceRegistry = getServices().get(BuildServiceRegistryInternal.class);
Provider<TestClustersThrottle> throttleProvider = Boilerplate.getBuildService(serviceRegistry, THROTTLE_SERVICE_NAME);
SharedResource resource = serviceRegistry.forService(throttleProvider);
int nodeCount = clusters.stream().mapToInt(cluster -> cluster.getNodes().size()).sum();
if (nodeCount > 0) {
locks.add(resource.getResourceLock(Math.min(nodeCount, resource.getMaxUsages())));
}
return Collections.unmodifiableList(locks);
}
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.gradle.testclusters;
import org.elasticsearch.gradle.DistributionDownloadPlugin;
import org.elasticsearch.gradle.ReaperPlugin;
import org.elasticsearch.gradle.ReaperService;
import org.elasticsearch.gradle.tool.Boilerplate;
import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.Plugin;
import org.gradle.api.Project;
@ -30,53 +31,50 @@ import org.gradle.api.execution.TaskExecutionListener;
import org.gradle.api.invocation.Gradle;
import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;
import org.gradle.api.provider.Provider;
import org.gradle.api.tasks.TaskState;
import java.io.File;
public class TestClustersPlugin implements Plugin<Project> {
private static final String LIST_TASK_NAME = "listTestClusters";
public static final String EXTENSION_NAME = "testClusters";
private static final String REGISTRY_EXTENSION_NAME = "testClustersRegistry";
public static final String THROTTLE_SERVICE_NAME = "testClustersThrottle";
private static final String LIST_TASK_NAME = "listTestClusters";
private static final String REGISTRY_SERVICE_NAME = "testClustersRegistry";
private static final Logger logger = Logging.getLogger(TestClustersPlugin.class);
private ReaperService reaper;
@Override
public void apply(Project project) {
project.getPlugins().apply(DistributionDownloadPlugin.class);
project.getRootProject().getPluginManager().apply(ReaperPlugin.class);
reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);
ReaperService reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);
// enable the DSL to describe clusters
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project);
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project, reaper);
// provide a task to be able to list defined clusters.
createListClustersTask(project, container);
if (project.getRootProject().getExtensions().findByName(REGISTRY_EXTENSION_NAME) == null) {
TestClustersRegistry registry = project.getRootProject()
.getExtensions()
.create(REGISTRY_EXTENSION_NAME, TestClustersRegistry.class);
// register cluster registry as a global build service
project.getGradle().getSharedServices().registerIfAbsent(REGISTRY_SERVICE_NAME, TestClustersRegistry.class, spec -> {});
// When we know what tasks will run, we claim the clusters of those task to differentiate between clusters
// that are defined in the build script and the ones that will actually be used in this invocation of gradle
// we use this information to determine when the last task that required the cluster executed so that we can
// terminate the cluster right away and free up resources.
configureClaimClustersHook(project.getGradle(), registry);
// register throttle so we only run at most max-workers/2 nodes concurrently
project.getGradle()
.getSharedServices()
.registerIfAbsent(
THROTTLE_SERVICE_NAME,
TestClustersThrottle.class,
spec -> spec.getMaxParallelUsages().set(project.getGradle().getStartParameter().getMaxWorkerCount() / 2)
);
// Before each task, we determine if a cluster needs to be started for that task.
configureStartClustersHook(project.getGradle(), registry);
// After each task we determine if there are clusters that are no longer needed.
configureStopClustersHook(project.getGradle(), registry);
}
// register cluster hooks
project.getRootProject().getPluginManager().apply(TestClustersHookPlugin.class);
}
private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project) {
private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project, ReaperService reaper) {
// Create an extensions that allows describing clusters
NamedDomainObjectContainer<ElasticsearchCluster> container = project.container(
ElasticsearchCluster.class,
@ -95,52 +93,78 @@ public class TestClustersPlugin implements Plugin<Project> {
);
}
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
// Once we know all the tasks that need to execute, we claim all the clusters that belong to those and count the
// claims so we'll know when it's safe to stop them.
gradle.getTaskGraph().whenReady(taskExecutionGraph -> {
taskExecutionGraph.getAllTasks()
.stream()
.filter(task -> task instanceof TestClustersAware)
.map(task -> (TestClustersAware) task)
.flatMap(task -> task.getClusters().stream())
.forEach(registry::claimCluster);
});
}
private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskActionListener() {
@Override
public void beforeActions(Task task) {
if (task instanceof TestClustersAware == false) {
return;
}
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
TestClustersAware awareTask = (TestClustersAware) task;
awareTask.beforeStart();
awareTask.getClusters().forEach(registry::maybeStartCluster);
static class TestClustersHookPlugin implements Plugin<Project> {
@Override
public void apply(Project project) {
if (project != project.getRootProject()) {
throw new IllegalStateException(this.getClass().getName() + " can only be applied to the root project.");
}
@Override
public void afterActions(Task task) {}
});
}
Provider<TestClustersRegistry> registryProvider = Boilerplate.getBuildService(
project.getGradle().getSharedServices(),
REGISTRY_SERVICE_NAME
);
TestClustersRegistry registry = registryProvider.get();
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskExecutionListener() {
@Override
public void afterExecute(Task task, TaskState state) {
if (task instanceof TestClustersAware == false) {
return;
// When we know what tasks will run, we claim the clusters of those task to differentiate between clusters
// that are defined in the build script and the ones that will actually be used in this invocation of gradle
// we use this information to determine when the last task that required the cluster executed so that we can
// terminate the cluster right away and free up resources.
configureClaimClustersHook(project.getGradle(), registry);
// Before each task, we determine if a cluster needs to be started for that task.
configureStartClustersHook(project.getGradle(), registry);
// After each task we determine if there are clusters that are no longer needed.
configureStopClustersHook(project.getGradle(), registry);
}
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
// Once we know all the tasks that need to execute, we claim all the clusters that belong to those and count the
// claims so we'll know when it's safe to stop them.
gradle.getTaskGraph().whenReady(taskExecutionGraph -> {
taskExecutionGraph.getAllTasks()
.stream()
.filter(task -> task instanceof TestClustersAware)
.map(task -> (TestClustersAware) task)
.flatMap(task -> task.getClusters().stream())
.forEach(registry::claimCluster);
});
}
private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskActionListener() {
@Override
public void beforeActions(Task task) {
if (task instanceof TestClustersAware == false) {
return;
}
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
TestClustersAware awareTask = (TestClustersAware) task;
awareTask.beforeStart();
awareTask.getClusters().forEach(registry::maybeStartCluster);
}
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
// and caused the cluster to start.
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
}
@Override
public void beforeExecute(Task task) {}
});
@Override
public void afterActions(Task task) {}
});
}
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskExecutionListener() {
@Override
public void afterExecute(Task task, TaskState state) {
if (task instanceof TestClustersAware == false) {
return;
}
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
// and caused the cluster to start.
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
}
@Override
public void beforeExecute(Task task) {}
});
}
}
}

View File

@ -2,13 +2,15 @@ package org.elasticsearch.gradle.testclusters;
import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceParameters;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class TestClustersRegistry {
public abstract class TestClustersRegistry implements BuildService<BuildServiceParameters.None> {
private static final Logger logger = Logging.getLogger(TestClustersRegistry.class);
private static final String TESTCLUSTERS_INSPECT_FAILURE = "testclusters.inspect.failure";
private final Boolean allowClusterToSurvive = Boolean.valueOf(System.getProperty(TESTCLUSTERS_INSPECT_FAILURE, "false"));

View File

@ -0,0 +1,6 @@
package org.elasticsearch.gradle.testclusters;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceParameters;
public abstract class TestClustersThrottle implements BuildService<BuildServiceParameters.None> {}

View File

@ -19,12 +19,17 @@
package org.elasticsearch.gradle.tool;
import org.gradle.api.Action;
import org.gradle.api.GradleException;
import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.PolymorphicDomainObjectContainer;
import org.gradle.api.Project;
import org.gradle.api.Task;
import org.gradle.api.UnknownTaskException;
import org.gradle.api.plugins.JavaPluginConvention;
import org.gradle.api.provider.Provider;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceRegistration;
import org.gradle.api.services.BuildServiceRegistry;
import org.gradle.api.tasks.SourceSetContainer;
import org.gradle.api.tasks.TaskContainer;
import org.gradle.api.tasks.TaskProvider;
@ -102,4 +107,14 @@ public abstract class Boilerplate {
return task;
}
@SuppressWarnings("unchecked")
public static <T extends BuildService<?>> Provider<T> getBuildService(BuildServiceRegistry registry, String name) {
BuildServiceRegistration<?, ?> registration = registry.getRegistrations().findByName(name);
if (registration == null) {
throw new GradleException("Unable to find build service with name '" + name + "'.");
}
return (Provider<T>) registration.getService();
}
}