Fix race in AzureClient factory fetch (#16525)

* Fix race in AzureClient factory fetch

* Fixing forbidden check.

* Renaming variable.
This commit is contained in:
Karan Kumar 2024-06-01 22:50:44 +05:30 committed by GitHub
parent b1568fb95b
commit d0916865d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 2 deletions

View File

@ -32,8 +32,8 @@ import org.apache.druid.java.util.common.Pair;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.time.Duration; import java.time.Duration;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/** /**
* Factory class for generating BlobServiceClient objects used for deep storage. * Factory class for generating BlobServiceClient objects used for deep storage.
@ -47,7 +47,7 @@ public class AzureClientFactory
public AzureClientFactory(AzureAccountConfig config) public AzureClientFactory(AzureAccountConfig config)
{ {
this.config = config; this.config = config;
this.cachedBlobServiceClients = new HashMap<>(); this.cachedBlobServiceClients = new ConcurrentHashMap<>();
} }
// It's okay to store clients in a map here because all the configs for specifying azure retries are static, and there are only 2 of them. // It's okay to store clients in a map here because all the configs for specifying azure retries are static, and there are only 2 of them.

View File

@ -24,11 +24,16 @@ import com.azure.core.http.policy.BearerTokenAuthenticationPolicy;
import com.azure.storage.blob.BlobServiceClient; import com.azure.storage.blob.BlobServiceClient;
import com.azure.storage.common.StorageSharedKeyCredential; import com.azure.storage.common.StorageSharedKeyCredential;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
public class AzureClientFactoryTest public class AzureClientFactoryTest
{ {
@ -173,4 +178,51 @@ public class AzureClientFactoryTest
BlobServiceClient blobServiceClient = azureClientFactory.getBlobServiceClient(null, ACCOUNT); BlobServiceClient blobServiceClient = azureClientFactory.getBlobServiceClient(null, ACCOUNT);
Assert.assertEquals(expectedAccountUrl.toString(), blobServiceClient.getAccountUrl()); Assert.assertEquals(expectedAccountUrl.toString(), blobServiceClient.getAccountUrl());
} }
@Test
public void test_concurrent_azureClientFactory_gets() throws Exception
{
for (int i = 0; i < 10; i++) {
concurrentAzureClientFactoryGets();
}
}
private void concurrentAzureClientFactoryGets() throws Exception
{
final int threads = 100;
String endpointSuffix = "core.nonDefault.windows.net";
String storageAccountEndpointSuffix = "ABC123.blob.storage.azure.net";
AzureAccountConfig config = new AzureAccountConfig();
config.setKey("key");
config.setEndpointSuffix(endpointSuffix);
config.setStorageAccountEndpointSuffix(storageAccountEndpointSuffix);
final AzureClientFactory localAzureClientFactory = new AzureClientFactory(config);
final URL expectedAccountUrl = new URL(
AzureAccountConfig.DEFAULT_PROTOCOL,
ACCOUNT + "." + storageAccountEndpointSuffix,
""
);
final CountDownLatch latch = new CountDownLatch(threads);
ExecutorService executorService = Execs.multiThreaded(threads, "azure-client-fetcher-%d");
final AtomicReference<Exception> failureException = new AtomicReference<>();
for (int i = 0; i < threads; i++) {
final int retry = i % 2;
executorService.submit(() -> {
try {
latch.countDown();
latch.await();
BlobServiceClient blobServiceClient = localAzureClientFactory.getBlobServiceClient(retry, ACCOUNT);
Assert.assertEquals(expectedAccountUrl.toString(), blobServiceClient.getAccountUrl());
}
catch (Exception e) {
failureException.compareAndSet(null, e);
}
});
}
executorService.awaitTermination(1000, TimeUnit.MICROSECONDS);
if (failureException.get() != null) {
throw failureException.get();
}
}
} }