diff --git a/core-java-modules/core-java-concurrency-basic-3/src/main/java/com/baeldung/concurrent/completablefuture/threadpool/CustomCompletableFuture.java b/core-java-modules/core-java-concurrency-basic-3/src/main/java/com/baeldung/concurrent/completablefuture/threadpool/CustomCompletableFuture.java new file mode 100644 index 0000000000..1f3997768e --- /dev/null +++ b/core-java-modules/core-java-concurrency-basic-3/src/main/java/com/baeldung/concurrent/completablefuture/threadpool/CustomCompletableFuture.java @@ -0,0 +1,28 @@ +package com.baeldung.concurrent.completablefuture.threadpool; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.function.Supplier; + +public class CustomCompletableFuture extends CompletableFuture { + private static final Executor executor = Executors.newSingleThreadExecutor(runnable -> new Thread(runnable, "Custom-Single-Thread")); + + public static CustomCompletableFuture supplyAsync(Supplier supplier) { + CustomCompletableFuture future = new CustomCompletableFuture<>(); + executor.execute(() -> { + try { + future.complete(supplier.get()); + } catch (Exception ex) { + future.completeExceptionally(ex); + } + }); + return future; + } + + @Override + public Executor defaultExecutor() { + return executor; + } + +} \ No newline at end of file diff --git a/core-java-modules/core-java-concurrency-basic-3/src/test/java/com/baeldung/concurrent/completablefuture/threadpool/CompletableFutureThreadPoolUnitTest.java b/core-java-modules/core-java-concurrency-basic-3/src/test/java/com/baeldung/concurrent/completablefuture/threadpool/CompletableFutureThreadPoolUnitTest.java new file mode 100644 index 0000000000..4f94f36131 --- /dev/null +++ b/core-java-modules/core-java-concurrency-basic-3/src/test/java/com/baeldung/concurrent/completablefuture/threadpool/CompletableFutureThreadPoolUnitTest.java @@ -0,0 +1,82 @@ +package com.baeldung.concurrent.completablefuture.threadpool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import org.junit.jupiter.api.Test; + +public class CompletableFutureThreadPoolUnitTest { + + @Test + void whenUsingNonAsync_thenUsesMainThread() { + CompletableFuture name = CompletableFuture.supplyAsync(() -> "Baeldung"); + + CompletableFuture nameLength = name.thenApply(value -> { + printCurrentThread(); + return value.length(); + }); + + assertThat(nameLength).isCompletedWithValue(8); + } + + @Test + void whenUsingNonAsync_thenUsesCallersThread() throws InterruptedException { + Runnable test = () -> { + CompletableFuture name = CompletableFuture.supplyAsync(() -> "Baeldung"); + + CompletableFuture nameLength = name.thenApply(value -> { + printCurrentThread(); + return value.length(); + }); + + assertThat(nameLength).isCompletedWithValue(8); + }; + + new Thread(test, "test-thread").start(); + Thread.sleep(100l); + } + + @Test + void whenUsingAsync_thenUsesCommonPool() { + CompletableFuture name = CompletableFuture.supplyAsync(() -> "Baeldung"); + + CompletableFuture nameLength = name.thenApplyAsync(value -> { + printCurrentThread(); + return value.length(); + }); + + assertThat(nameLength).isCompletedWithValue(8); + } + + @Test + void whenUsingAsync_thenUsesCustomExecutor() { + Executor testExecutor = Executors.newFixedThreadPool(5); + CompletableFuture name = CompletableFuture.supplyAsync(() -> "Baeldung"); + + CompletableFuture nameLength = name.thenApplyAsync(value -> { + printCurrentThread(); + return value.length(); + }, testExecutor); + + assertThat(nameLength).isCompletedWithValue(8); + } + + @Test + void whenOverridingDefaultThreadPool_thenUsesCustomExecutor() { + CompletableFuture name = CustomCompletableFuture.supplyAsync(() -> "Baeldung"); + + CompletableFuture nameLength = name.thenApplyAsync(value -> { + printCurrentThread(); + return value.length(); + }); + + assertThat(nameLength).isCompletedWithValue(8); + } + + private static void printCurrentThread() { + System.out.println(Thread.currentThread().getName()); + } +}