Merge pull request #11478 from hkhan/JAVA-8335-fix-concurrency-test

[JAVA-8335] Fix intermittent unit test failure
This commit is contained in:
kwoyke 2021-11-22 08:03:10 +01:00 committed by GitHub
commit f3f1a60d57
2 changed files with 30 additions and 27 deletions

View File

@ -3,17 +3,18 @@ package com.baeldung.abaproblem;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly;
public class Account { public class Account {
private AtomicInteger balance; private final AtomicInteger balance;
private AtomicInteger transactionCount; private final AtomicInteger transactionCount;
private ThreadLocal<Integer> currentThreadCASFailureCount; private final ThreadLocal<Integer> currentThreadCASFailureCount;
public Account() { public Account() {
this.balance = new AtomicInteger(0); this.balance = new AtomicInteger(0);
this.transactionCount = new AtomicInteger(0); this.transactionCount = new AtomicInteger(0);
this.currentThreadCASFailureCount = new ThreadLocal<>(); this.currentThreadCASFailureCount = ThreadLocal.withInitial(() -> 0);
this.currentThreadCASFailureCount.set(0);
} }
public int getBalance() { public int getBalance() {
@ -43,11 +44,7 @@ public class Account {
private void maybeWait() { private void maybeWait() {
if ("thread1".equals(Thread.currentThread().getName())) { if ("thread1".equals(Thread.currentThread().getName())) {
try { sleepUninterruptibly(2, TimeUnit.SECONDS);
TimeUnit.SECONDS.sleep(2);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
} }
} }

View File

@ -1,8 +1,13 @@
package com.baeldung.abaproblem; package com.baeldung.abaproblem;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -30,45 +35,39 @@ public class AccountUnitTest {
assertTrue(account.deposit(moneyToDeposit)); assertTrue(account.deposit(moneyToDeposit));
assertEquals(moneyToDeposit, account.getBalance()); assertEquals(moneyToDeposit, account.getBalance());
assertEquals(1, account.getTransactionCount());
} }
@Test @Test
public void withdrawTest() throws InterruptedException { public void withdrawTest() {
final int defaultBalance = 50; final int defaultBalance = 50;
final int moneyToWithdraw = 20; final int moneyToWithdraw = 20;
account.deposit(defaultBalance); account.deposit(defaultBalance);
assertTrue(account.withdraw(moneyToWithdraw)); assertTrue(account.withdraw(moneyToWithdraw));
assertEquals(defaultBalance - moneyToWithdraw, account.getBalance()); assertEquals(defaultBalance - moneyToWithdraw, account.getBalance());
} }
@Test @Test
public void abaProblemTest() throws InterruptedException { public void abaProblemTest() throws Exception {
final int defaultBalance = 50; final int defaultBalance = 50;
final int amountToWithdrawByThread1 = 20; final int amountToWithdrawByThread1 = 20;
final int amountToWithdrawByThread2 = 10; final int amountToWithdrawByThread2 = 10;
final int amountToDepositByThread2 = 10; final int amountToDepositByThread2 = 10;
assertEquals(0, account.getTransactionCount());
assertEquals(0, account.getCurrentThreadCASFailureCount());
account.deposit(defaultBalance); account.deposit(defaultBalance);
assertEquals(1, account.getTransactionCount());
Thread thread1 = new Thread(() -> {
Runnable thread1 = () -> {
// this will take longer due to the name of the thread // this will take longer due to the name of the thread
assertTrue(account.withdraw(amountToWithdrawByThread1)); assertTrue(account.withdraw(amountToWithdrawByThread1));
// thread 1 fails to capture ABA problem // thread 1 fails to capture ABA problem
assertNotEquals(1, account.getCurrentThreadCASFailureCount()); assertNotEquals(1, account.getCurrentThreadCASFailureCount());
};
}, "thread1"); Runnable thread2 = () -> {
Thread thread2 = new Thread(() -> {
assertTrue(account.deposit(amountToDepositByThread2)); assertTrue(account.deposit(amountToDepositByThread2));
assertEquals(defaultBalance + amountToDepositByThread2, account.getBalance()); assertEquals(defaultBalance + amountToDepositByThread2, account.getBalance());
@ -79,12 +78,13 @@ public class AccountUnitTest {
assertEquals(defaultBalance, account.getBalance()); assertEquals(defaultBalance, account.getBalance());
assertEquals(0, account.getCurrentThreadCASFailureCount()); assertEquals(0, account.getCurrentThreadCASFailureCount());
}, "thread2"); };
thread1.start(); Future<?> future1 = getSingleThreadExecutorService("thread1").submit(thread1);
thread2.start(); Future<?> future2 = getSingleThreadExecutorService("thread2").submit(thread2);
thread1.join();
thread2.join(); future1.get();
future2.get();
// compareAndSet operation succeeds for thread 1 // compareAndSet operation succeeds for thread 1
assertEquals(defaultBalance - amountToWithdrawByThread1, account.getBalance()); assertEquals(defaultBalance - amountToWithdrawByThread1, account.getBalance());
@ -95,4 +95,10 @@ public class AccountUnitTest {
// thread 2 did two modifications as well // thread 2 did two modifications as well
assertEquals(4, account.getTransactionCount()); assertEquals(4, account.getTransactionCount());
} }
private static ExecutorService getSingleThreadExecutorService(String threadName) {
return Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder().setNameFormat(threadName).build()
);
}
} }