diff --git a/core-java-modules/core-java-concurrency-basic-3/src/main/java/com/baeldung/concurrent/completablefuture/retry/RetryCompletableFuture.java b/core-java-modules/core-java-concurrency-basic-3/src/main/java/com/baeldung/concurrent/completablefuture/retry/RetryCompletableFuture.java new file mode 100644 index 0000000000..41f1309311 --- /dev/null +++ b/core-java-modules/core-java-concurrency-basic-3/src/main/java/com/baeldung/concurrent/completablefuture/retry/RetryCompletableFuture.java @@ -0,0 +1,63 @@ +package com.baeldung.concurrent.completablefuture.retry; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.function.Supplier; + +public class RetryCompletableFuture { + public static CompletableFuture retryTask(Supplier supplier, int maxRetries) { + Supplier retryableSupplier = retryFunction(supplier, maxRetries); + return CompletableFuture.supplyAsync(retryableSupplier); + } + + static Supplier retryFunction(Supplier supplier, int maxRetries) { + return () -> { + int retries = 0; + while (retries < maxRetries) { + try { + return supplier.get(); + } catch (Exception e) { + retries++; + } + } + throw new IllegalStateException(String.format("Task failed after %s attempts", maxRetries)); + }; + } + + public static CompletableFuture retryUnsafe(Supplier supplier, int maxRetries) { + CompletableFuture cf = CompletableFuture.supplyAsync(supplier); + sleep(100l); + for (int i = 0; i < maxRetries; i++) { + cf = cf.exceptionally(__ -> supplier.get()); + } + return cf; + } + + public static CompletableFuture retryNesting(Supplier supplier, int maxRetries) { + CompletableFuture cf = CompletableFuture.supplyAsync(supplier); + sleep(100); + for (int i = 0; i < maxRetries; i++) { + cf = cf.thenApply(CompletableFuture::completedFuture) + .exceptionally(__ -> CompletableFuture.supplyAsync(supplier)) + .thenCompose(Function.identity()); + } + return cf; + } + + public static CompletableFuture retryExceptionallyAsync(Supplier supplier, int maxRetries) { + CompletableFuture cf = CompletableFuture.supplyAsync(supplier); + sleep(100); + for (int i = 0; i < maxRetries; i++) { + cf = cf.exceptionallyAsync(__ -> supplier.get()); + } + return cf; + } + + private static void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/core-java-modules/core-java-concurrency-basic-3/src/test/java/com/baeldung/concurrent/completablefuture/retry/RetryCompletableFutureUnitTest.java b/core-java-modules/core-java-concurrency-basic-3/src/test/java/com/baeldung/concurrent/completablefuture/retry/RetryCompletableFutureUnitTest.java new file mode 100644 index 0000000000..b48039d4a6 --- /dev/null +++ b/core-java-modules/core-java-concurrency-basic-3/src/test/java/com/baeldung/concurrent/completablefuture/retry/RetryCompletableFutureUnitTest.java @@ -0,0 +1,125 @@ +package com.baeldung.concurrent.completablefuture.retry; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import static com.baeldung.concurrent.completablefuture.retry.RetryCompletableFuture.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class RetryCompletableFutureUnitTest { + private AtomicInteger retriesCounter = new AtomicInteger(0); + + @BeforeEach + void beforeEach() { + retriesCounter.set(0); + } + + @Test + void whenRetryingTask_thenReturnsCorrectlyAfterFourInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryTask(codeToRun, 10); + + assertThat(result.join()) + .isEqualTo(100); + assertThat(retriesCounter) + .hasValue(4); + } + + @Test + void whenRetryingTask_thenThrowsExceptionAfterThreeInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryTask(codeToRun, 3); + + assertThatThrownBy(result::join) + .isInstanceOf(CompletionException.class) + .hasMessageContaining("IllegalStateException: Task failed after 3 attempts"); + } + + @Test + void whenRetryingExceptionally_thenReturnsCorrectlyAfterFourInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryUnsafe(codeToRun, 10); + + assertThat(result.join()) + .isEqualTo(100); + assertThat(retriesCounter) + .hasValue(4); + } + + @Test + void whenRetryingExceptionally_thenThrowsExceptionAfterThreeInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryUnsafe(codeToRun, 3); + + assertThatThrownBy(result::join) + .isInstanceOf(CompletionException.class) + .hasMessageContaining("RuntimeException: task failed for 3 time(s)"); + } + + @Test + void whenRetryingExceptionallyAsync_thenReturnsCorrectlyAfterFourInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryExceptionallyAsync(codeToRun, 10); + + assertThat(result.join()) + .isEqualTo(100); + assertThat(retriesCounter) + .hasValue(4); + } + + @Test + void whenRetryingExceptionallyAsync_thenThrowsExceptionAfterThreeInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryExceptionallyAsync(codeToRun, 3); + + assertThatThrownBy(result::join) + .isInstanceOf(CompletionException.class) + .hasMessageContaining("RuntimeException: task failed for 3 time(s)"); + } + + @Test + void whenRetryingNesting_thenReturnsCorrectlyAfterFourInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryNesting(codeToRun, 10); + + assertThat(result.join()) + .isEqualTo(100); + assertThat(retriesCounter) + .hasValue(4); + } + + @Test + void whenRetryingNesting_thenThrowsExceptionAfterThreeInvocations() { + Supplier codeToRun = () -> failFourTimesThenReturn(100); + + CompletableFuture result = retryNesting(codeToRun, 3); + + assertThatThrownBy(result::join) + .isInstanceOf(CompletionException.class) + .hasMessageContaining("RuntimeException: task failed for 3 time(s)"); + } + + int failFourTimesThenReturn(int returnValue) { + int retryNr = retriesCounter.get(); + System.out.println(String.format("invocation: %s, thread: %s", retryNr, Thread.currentThread().getName())); + if (retryNr < 4) { + retriesCounter.set(retryNr + 1); + throw new RuntimeException(String.format("task failed for %s time(s)", retryNr)); + } + return returnValue; + } + +}