Limit max concurrency of test cluster nodes to a function of max workers (#51338)

(cherry picked from commit 9a0238c7166e70e467ca61c1353157979fd1598b)
This commit is contained in:
Mark Vieira 2020-01-23 15:20:17 -08:00
parent 40f4f2d267
commit 4f214d20ab
No known key found for this signature in database
GPG Key ID: CA947EF7E6D4B105
5 changed files with 138 additions and 66 deletions

View File

@ -1,12 +1,22 @@
package org.elasticsearch.gradle.testclusters; package org.elasticsearch.gradle.testclusters;
import org.elasticsearch.gradle.tool.Boilerplate;
import org.gradle.api.provider.Provider;
import org.gradle.api.services.internal.BuildServiceRegistryInternal;
import org.gradle.api.tasks.CacheableTask; import org.gradle.api.tasks.CacheableTask;
import org.gradle.api.tasks.Internal;
import org.gradle.api.tasks.Nested; import org.gradle.api.tasks.Nested;
import org.gradle.api.tasks.testing.Test; import org.gradle.api.tasks.testing.Test;
import org.gradle.internal.resources.ResourceLock;
import org.gradle.internal.resources.SharedResource;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import static org.elasticsearch.gradle.testclusters.TestClustersPlugin.THROTTLE_SERVICE_NAME;
/** /**
* Customized version of Gradle {@link Test} task which tracks a collection of {@link ElasticsearchCluster} as a task input. We must do this * Customized version of Gradle {@link Test} task which tracks a collection of {@link ElasticsearchCluster} as a task input. We must do this
@ -47,4 +57,19 @@ public class RestTestRunnerTask extends Test implements TestClustersAware {
return clusters; return clusters;
} }
@Override
@Internal
public List<ResourceLock> getSharedResources() {
List<ResourceLock> locks = new ArrayList<>(super.getSharedResources());
BuildServiceRegistryInternal serviceRegistry = getServices().get(BuildServiceRegistryInternal.class);
Provider<TestClustersThrottle> throttleProvider = Boilerplate.getBuildService(serviceRegistry, THROTTLE_SERVICE_NAME);
SharedResource resource = serviceRegistry.forService(throttleProvider);
int nodeCount = clusters.stream().mapToInt(cluster -> cluster.getNodes().size()).sum();
if (nodeCount > 0) {
locks.add(resource.getResourceLock(Math.min(nodeCount, resource.getMaxUsages())));
}
return Collections.unmodifiableList(locks);
}
} }

View File

@ -21,6 +21,7 @@ package org.elasticsearch.gradle.testclusters;
import org.elasticsearch.gradle.DistributionDownloadPlugin; import org.elasticsearch.gradle.DistributionDownloadPlugin;
import org.elasticsearch.gradle.ReaperPlugin; import org.elasticsearch.gradle.ReaperPlugin;
import org.elasticsearch.gradle.ReaperService; import org.elasticsearch.gradle.ReaperService;
import org.elasticsearch.gradle.tool.Boilerplate;
import org.gradle.api.NamedDomainObjectContainer; import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.Plugin; import org.gradle.api.Plugin;
import org.gradle.api.Project; import org.gradle.api.Project;
@ -30,53 +31,50 @@ import org.gradle.api.execution.TaskExecutionListener;
import org.gradle.api.invocation.Gradle; import org.gradle.api.invocation.Gradle;
import org.gradle.api.logging.Logger; import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging; import org.gradle.api.logging.Logging;
import org.gradle.api.provider.Provider;
import org.gradle.api.tasks.TaskState; import org.gradle.api.tasks.TaskState;
import java.io.File; import java.io.File;
public class TestClustersPlugin implements Plugin<Project> { public class TestClustersPlugin implements Plugin<Project> {
private static final String LIST_TASK_NAME = "listTestClusters";
public static final String EXTENSION_NAME = "testClusters"; public static final String EXTENSION_NAME = "testClusters";
private static final String REGISTRY_EXTENSION_NAME = "testClustersRegistry"; public static final String THROTTLE_SERVICE_NAME = "testClustersThrottle";
private static final String LIST_TASK_NAME = "listTestClusters";
private static final String REGISTRY_SERVICE_NAME = "testClustersRegistry";
private static final Logger logger = Logging.getLogger(TestClustersPlugin.class); private static final Logger logger = Logging.getLogger(TestClustersPlugin.class);
private ReaperService reaper;
@Override @Override
public void apply(Project project) { public void apply(Project project) {
project.getPlugins().apply(DistributionDownloadPlugin.class); project.getPlugins().apply(DistributionDownloadPlugin.class);
project.getRootProject().getPluginManager().apply(ReaperPlugin.class); project.getRootProject().getPluginManager().apply(ReaperPlugin.class);
reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);
ReaperService reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);
// enable the DSL to describe clusters // enable the DSL to describe clusters
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project); NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project, reaper);
// provide a task to be able to list defined clusters. // provide a task to be able to list defined clusters.
createListClustersTask(project, container); createListClustersTask(project, container);
if (project.getRootProject().getExtensions().findByName(REGISTRY_EXTENSION_NAME) == null) { // register cluster registry as a global build service
TestClustersRegistry registry = project.getRootProject() project.getGradle().getSharedServices().registerIfAbsent(REGISTRY_SERVICE_NAME, TestClustersRegistry.class, spec -> {});
.getExtensions()
.create(REGISTRY_EXTENSION_NAME, TestClustersRegistry.class);
// When we know what tasks will run, we claim the clusters of those task to differentiate between clusters // register throttle so we only run at most max-workers/2 nodes concurrently
// that are defined in the build script and the ones that will actually be used in this invocation of gradle project.getGradle()
// we use this information to determine when the last task that required the cluster executed so that we can .getSharedServices()
// terminate the cluster right away and free up resources. .registerIfAbsent(
configureClaimClustersHook(project.getGradle(), registry); THROTTLE_SERVICE_NAME,
TestClustersThrottle.class,
spec -> spec.getMaxParallelUsages().set(project.getGradle().getStartParameter().getMaxWorkerCount() / 2)
);
// Before each task, we determine if a cluster needs to be started for that task. // register cluster hooks
configureStartClustersHook(project.getGradle(), registry); project.getRootProject().getPluginManager().apply(TestClustersHookPlugin.class);
// After each task we determine if there are clusters that are no longer needed.
configureStopClustersHook(project.getGradle(), registry);
}
} }
private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project) { private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project, ReaperService reaper) {
// Create an extensions that allows describing clusters // Create an extensions that allows describing clusters
NamedDomainObjectContainer<ElasticsearchCluster> container = project.container( NamedDomainObjectContainer<ElasticsearchCluster> container = project.container(
ElasticsearchCluster.class, ElasticsearchCluster.class,
@ -95,52 +93,78 @@ public class TestClustersPlugin implements Plugin<Project> {
); );
} }
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) { static class TestClustersHookPlugin implements Plugin<Project> {
// Once we know all the tasks that need to execute, we claim all the clusters that belong to those and count the @Override
// claims so we'll know when it's safe to stop them. public void apply(Project project) {
gradle.getTaskGraph().whenReady(taskExecutionGraph -> { if (project != project.getRootProject()) {
taskExecutionGraph.getAllTasks() throw new IllegalStateException(this.getClass().getName() + " can only be applied to the root project.");
.stream()
.filter(task -> task instanceof TestClustersAware)
.map(task -> (TestClustersAware) task)
.flatMap(task -> task.getClusters().stream())
.forEach(registry::claimCluster);
});
}
private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskActionListener() {
@Override
public void beforeActions(Task task) {
if (task instanceof TestClustersAware == false) {
return;
}
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
TestClustersAware awareTask = (TestClustersAware) task;
awareTask.beforeStart();
awareTask.getClusters().forEach(registry::maybeStartCluster);
} }
@Override Provider<TestClustersRegistry> registryProvider = Boilerplate.getBuildService(
public void afterActions(Task task) {} project.getGradle().getSharedServices(),
}); REGISTRY_SERVICE_NAME
} );
TestClustersRegistry registry = registryProvider.get();
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) { // When we know what tasks will run, we claim the clusters of those task to differentiate between clusters
gradle.addListener(new TaskExecutionListener() { // that are defined in the build script and the ones that will actually be used in this invocation of gradle
@Override // we use this information to determine when the last task that required the cluster executed so that we can
public void afterExecute(Task task, TaskState state) { // terminate the cluster right away and free up resources.
if (task instanceof TestClustersAware == false) { configureClaimClustersHook(project.getGradle(), registry);
return;
// Before each task, we determine if a cluster needs to be started for that task.
configureStartClustersHook(project.getGradle(), registry);
// After each task we determine if there are clusters that are no longer needed.
configureStopClustersHook(project.getGradle(), registry);
}
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
// Once we know all the tasks that need to execute, we claim all the clusters that belong to those and count the
// claims so we'll know when it's safe to stop them.
gradle.getTaskGraph().whenReady(taskExecutionGraph -> {
taskExecutionGraph.getAllTasks()
.stream()
.filter(task -> task instanceof TestClustersAware)
.map(task -> (TestClustersAware) task)
.flatMap(task -> task.getClusters().stream())
.forEach(registry::claimCluster);
});
}
private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskActionListener() {
@Override
public void beforeActions(Task task) {
if (task instanceof TestClustersAware == false) {
return;
}
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
TestClustersAware awareTask = (TestClustersAware) task;
awareTask.beforeStart();
awareTask.getClusters().forEach(registry::maybeStartCluster);
} }
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
// and caused the cluster to start.
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
}
@Override @Override
public void beforeExecute(Task task) {} public void afterActions(Task task) {}
}); });
}
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskExecutionListener() {
@Override
public void afterExecute(Task task, TaskState state) {
if (task instanceof TestClustersAware == false) {
return;
}
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
// and caused the cluster to start.
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
}
@Override
public void beforeExecute(Task task) {}
});
}
} }
} }

View File

@ -2,13 +2,15 @@ package org.elasticsearch.gradle.testclusters;
import org.gradle.api.logging.Logger; import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging; import org.gradle.api.logging.Logging;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceParameters;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
public class TestClustersRegistry { public abstract class TestClustersRegistry implements BuildService<BuildServiceParameters.None> {
private static final Logger logger = Logging.getLogger(TestClustersRegistry.class); private static final Logger logger = Logging.getLogger(TestClustersRegistry.class);
private static final String TESTCLUSTERS_INSPECT_FAILURE = "testclusters.inspect.failure"; private static final String TESTCLUSTERS_INSPECT_FAILURE = "testclusters.inspect.failure";
private final Boolean allowClusterToSurvive = Boolean.valueOf(System.getProperty(TESTCLUSTERS_INSPECT_FAILURE, "false")); private final Boolean allowClusterToSurvive = Boolean.valueOf(System.getProperty(TESTCLUSTERS_INSPECT_FAILURE, "false"));

View File

@ -0,0 +1,6 @@
package org.elasticsearch.gradle.testclusters;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceParameters;
public abstract class TestClustersThrottle implements BuildService<BuildServiceParameters.None> {}

View File

@ -19,12 +19,17 @@
package org.elasticsearch.gradle.tool; package org.elasticsearch.gradle.tool;
import org.gradle.api.Action; import org.gradle.api.Action;
import org.gradle.api.GradleException;
import org.gradle.api.NamedDomainObjectContainer; import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.PolymorphicDomainObjectContainer; import org.gradle.api.PolymorphicDomainObjectContainer;
import org.gradle.api.Project; import org.gradle.api.Project;
import org.gradle.api.Task; import org.gradle.api.Task;
import org.gradle.api.UnknownTaskException; import org.gradle.api.UnknownTaskException;
import org.gradle.api.plugins.JavaPluginConvention; import org.gradle.api.plugins.JavaPluginConvention;
import org.gradle.api.provider.Provider;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceRegistration;
import org.gradle.api.services.BuildServiceRegistry;
import org.gradle.api.tasks.SourceSetContainer; import org.gradle.api.tasks.SourceSetContainer;
import org.gradle.api.tasks.TaskContainer; import org.gradle.api.tasks.TaskContainer;
import org.gradle.api.tasks.TaskProvider; import org.gradle.api.tasks.TaskProvider;
@ -102,4 +107,14 @@ public abstract class Boilerplate {
return task; return task;
} }
@SuppressWarnings("unchecked")
public static <T extends BuildService<?>> Provider<T> getBuildService(BuildServiceRegistry registry, String name) {
BuildServiceRegistration<?, ?> registration = registry.getRegistrations().findByName(name);
if (registration == null) {
throw new GradleException("Unable to find build service with name '" + name + "'.");
}
return (Provider<T>) registration.getService();
}
} }