AMQ-6602: Fix race condition in TaskRunnerFactory

Fixing a race condition in TaskRunnerFactory where if multiple threads
call createTaskRunner() at the same time some threads might see the
executor as null (if it hasn't finished initializing) leading to the
creation of extra DedicatedTaskRunner objects instead of sharing a
PooledTaskRunner.
This commit is contained in:
Christopher L. Shannon (cshannon) 2017-02-23 10:18:22 -05:00
parent 816f81e605
commit fe5164a404
2 changed files with 102 additions and 15 deletions

View File

@ -25,6 +25,7 @@ import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.activemq.util.ThreadPoolUtils;
import org.slf4j.Logger;
@ -42,7 +43,7 @@ import org.slf4j.LoggerFactory;
public class TaskRunnerFactory implements Executor {
private static final Logger LOG = LoggerFactory.getLogger(TaskRunnerFactory.class);
private ExecutorService executor;
private final AtomicReference<ExecutorService> executorRef = new AtomicReference<>();
private int maxIterationsPerRun;
private String name;
private int priority;
@ -81,15 +82,23 @@ public class TaskRunnerFactory implements Executor {
}
public void init() {
if (initDone.compareAndSet(false, true)) {
if (!initDone.get()) {
// If your OS/JVM combination has a good thread model, you may want to
// avoid using a thread pool to run tasks and use a DedicatedTaskRunner instead.
//AMQ-6602 - lock instead of using compareAndSet to prevent threads from seeing a null value
//for executorRef inside createTaskRunner() on contention and creating a DedicatedTaskRunner
synchronized(this) {
//need to recheck if initDone is true under the lock
if (!initDone.get()) {
if (dedicatedTaskRunner || "true".equalsIgnoreCase(System.getProperty("org.apache.activemq.UseDedicatedTaskRunner"))) {
executor = null;
} else if (executor == null) {
executor = createDefaultExecutor();
executorRef.set(null);
} else {
executorRef.compareAndSet(null, createDefaultExecutor());
}
LOG.debug("Initialized TaskRunnerFactory[{}] using ExecutorService: {}", name, executorRef.get());
initDone.set(true);
}
}
LOG.debug("Initialized TaskRunnerFactory[{}] using ExecutorService: {}", name, executor);
}
}
@ -99,11 +108,11 @@ public class TaskRunnerFactory implements Executor {
* @see ThreadPoolUtils#shutdown(java.util.concurrent.ExecutorService)
*/
public void shutdown() {
ExecutorService executor = executorRef.get();
if (executor != null) {
ThreadPoolUtils.shutdown(executor);
executor = null;
}
initDone.set(false);
clearExecutor();
}
/**
@ -112,11 +121,11 @@ public class TaskRunnerFactory implements Executor {
* @see ThreadPoolUtils#shutdownNow(java.util.concurrent.ExecutorService)
*/
public void shutdownNow() {
ExecutorService executor = executorRef.get();
if (executor != null) {
ThreadPoolUtils.shutdownNow(executor);
executor = null;
}
initDone.set(false);
clearExecutor();
}
/**
@ -125,15 +134,25 @@ public class TaskRunnerFactory implements Executor {
* @see ThreadPoolUtils#shutdownGraceful(java.util.concurrent.ExecutorService)
*/
public void shutdownGraceful() {
ExecutorService executor = executorRef.get();
if (executor != null) {
ThreadPoolUtils.shutdownGraceful(executor, shutdownAwaitTermination);
executor = null;
}
clearExecutor();
}
private void clearExecutor() {
//clear under a lock to prevent threads from seeing initDone == true
//but then getting null from executorRef
synchronized(this) {
executorRef.set(null);
initDone.set(false);
}
}
public TaskRunner createTaskRunner(Task task, String name) {
init();
ExecutorService executor = executorRef.get();
if (executor != null) {
return new PooledTaskRunner(executor, task, maxIterationsPerRun);
} else {
@ -149,6 +168,7 @@ public class TaskRunnerFactory implements Executor {
public void execute(Runnable runnable, String name) {
init();
LOG.trace("Execute[{}] runnable: {}", name, runnable);
ExecutorService executor = executorRef.get();
if (executor != null) {
executor.execute(runnable);
} else {
@ -198,11 +218,11 @@ public class TaskRunnerFactory implements Executor {
}
public ExecutorService getExecutor() {
return executor;
return executorRef.get();
}
public void setExecutor(ExecutorService executor) {
this.executor = executor;
this.executorRef.set(executor);
}
public int getMaxIterationsPerRun() {

View File

@ -0,0 +1,67 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.thread;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.junit.Test;
public class TaskRunnerFactoryTest {
/**
* AMQ-6602 test
* Test contention on createTaskRunner() to make sure that all threads end up
* using a PooledTaskRunner
*
* @throws Exception
*/
@Test
public void testConcurrentTaskRunnerCreaction() throws Exception {
final TaskRunnerFactory factory = new TaskRunnerFactory();
final ExecutorService service = Executors.newFixedThreadPool(10);
final CountDownLatch latch1 = new CountDownLatch(1);
final CountDownLatch latch2 = new CountDownLatch(10);
final List<TaskRunner> runners = Collections.synchronizedList(new ArrayList<>(10));
for (int i = 0; i < 10; i++) {
service.execute(() -> {
try {
latch1.await();
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
runners.add(factory.createTaskRunner(() -> true, "task") );
latch2.countDown();
});
}
latch1.countDown();
latch2.await();
for (TaskRunner runner : runners) {
assertTrue(runner instanceof PooledTaskRunner);
}
}
}