From d0916865d045118e6c704345382b5b4cd7668966 Mon Sep 17 00:00:00 2001 From: Karan Kumar Date: Sat, 1 Jun 2024 22:50:44 +0530 Subject: [PATCH] Fix race in AzureClient factory fetch (#16525) * Fix race in AzureClient factory fetch * Fixing forbidden check. * Renaming variable. --- .../storage/azure/AzureClientFactory.java | 4 +- .../storage/azure/AzureClientFactoryTest.java | 52 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/extensions-core/azure-extensions/src/main/java/org/apache/druid/storage/azure/AzureClientFactory.java b/extensions-core/azure-extensions/src/main/java/org/apache/druid/storage/azure/AzureClientFactory.java index 7afde0466d5..a6e1ef2f49e 100644 --- a/extensions-core/azure-extensions/src/main/java/org/apache/druid/storage/azure/AzureClientFactory.java +++ b/extensions-core/azure-extensions/src/main/java/org/apache/druid/storage/azure/AzureClientFactory.java @@ -32,8 +32,8 @@ import org.apache.druid.java.util.common.Pair; import javax.annotation.Nullable; import java.time.Duration; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * Factory class for generating BlobServiceClient objects used for deep storage. @@ -47,7 +47,7 @@ public class AzureClientFactory public AzureClientFactory(AzureAccountConfig config) { this.config = config; - this.cachedBlobServiceClients = new HashMap<>(); + this.cachedBlobServiceClients = new ConcurrentHashMap<>(); } // It's okay to store clients in a map here because all the configs for specifying azure retries are static, and there are only 2 of them. diff --git a/extensions-core/azure-extensions/src/test/java/org/apache/druid/storage/azure/AzureClientFactoryTest.java b/extensions-core/azure-extensions/src/test/java/org/apache/druid/storage/azure/AzureClientFactoryTest.java index 1361a9351c0..795f0054224 100644 --- a/extensions-core/azure-extensions/src/test/java/org/apache/druid/storage/azure/AzureClientFactoryTest.java +++ b/extensions-core/azure-extensions/src/test/java/org/apache/druid/storage/azure/AzureClientFactoryTest.java @@ -24,11 +24,16 @@ import com.azure.core.http.policy.BearerTokenAuthenticationPolicy; import com.azure.storage.blob.BlobServiceClient; import com.azure.storage.common.StorageSharedKeyCredential; import com.google.common.collect.ImmutableMap; +import org.apache.druid.java.util.common.concurrent.Execs; import org.junit.Assert; import org.junit.Test; import java.net.MalformedURLException; import java.net.URL; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; public class AzureClientFactoryTest { @@ -173,4 +178,51 @@ public class AzureClientFactoryTest BlobServiceClient blobServiceClient = azureClientFactory.getBlobServiceClient(null, ACCOUNT); Assert.assertEquals(expectedAccountUrl.toString(), blobServiceClient.getAccountUrl()); } + + @Test + public void test_concurrent_azureClientFactory_gets() throws Exception + { + for (int i = 0; i < 10; i++) { + concurrentAzureClientFactoryGets(); + } + } + + private void concurrentAzureClientFactoryGets() throws Exception + { + final int threads = 100; + String endpointSuffix = "core.nonDefault.windows.net"; + String storageAccountEndpointSuffix = "ABC123.blob.storage.azure.net"; + AzureAccountConfig config = new AzureAccountConfig(); + config.setKey("key"); + config.setEndpointSuffix(endpointSuffix); + config.setStorageAccountEndpointSuffix(storageAccountEndpointSuffix); + final AzureClientFactory localAzureClientFactory = new AzureClientFactory(config); + final URL expectedAccountUrl = new URL( + AzureAccountConfig.DEFAULT_PROTOCOL, + ACCOUNT + "." + storageAccountEndpointSuffix, + "" + ); + + final CountDownLatch latch = new CountDownLatch(threads); + ExecutorService executorService = Execs.multiThreaded(threads, "azure-client-fetcher-%d"); + final AtomicReference failureException = new AtomicReference<>(); + for (int i = 0; i < threads; i++) { + final int retry = i % 2; + executorService.submit(() -> { + try { + latch.countDown(); + latch.await(); + BlobServiceClient blobServiceClient = localAzureClientFactory.getBlobServiceClient(retry, ACCOUNT); + Assert.assertEquals(expectedAccountUrl.toString(), blobServiceClient.getAccountUrl()); + } + catch (Exception e) { + failureException.compareAndSet(null, e); + } + }); + } + executorService.awaitTermination(1000, TimeUnit.MICROSECONDS); + if (failureException.get() != null) { + throw failureException.get(); + } + } }