diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java index a95d7e69172..c514beb47d3 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java @@ -250,9 +250,8 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, * - HOSTNAME = string * - PORT = integer * - * If multiple hosts are provider, the Factory will create a - * {@link LoadBalancingKMSClientProvider} that round-robins requests - * across the provided list of hosts. + * This will always create a {@link LoadBalancingKMSClientProvider} + * if the uri is correct. */ @Override public KeyProvider createProvider(URI providerUri, Configuration conf) @@ -279,30 +278,26 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, } hostsPart = t[0]; } - return createProvider(providerUri, conf, origUrl, port, hostsPart); + return createProvider(conf, origUrl, port, hostsPart); } return null; } - private KeyProvider createProvider(URI providerUri, Configuration conf, + private KeyProvider createProvider(Configuration conf, URL origUrl, int port, String hostsPart) throws IOException { String[] hosts = hostsPart.split(";"); - if (hosts.length == 1) { - return new KMSClientProvider(providerUri, conf); - } else { - KMSClientProvider[] providers = new KMSClientProvider[hosts.length]; - for (int i = 0; i < hosts.length; i++) { - try { - providers[i] = - new KMSClientProvider( - new URI("kms", origUrl.getProtocol(), hosts[i], port, - origUrl.getPath(), null, null), conf); - } catch (URISyntaxException e) { - throw new IOException("Could not instantiate KMSProvider..", e); - } + KMSClientProvider[] providers = new KMSClientProvider[hosts.length]; + for (int i = 0; i < hosts.length; i++) { + try { + providers[i] = + new KMSClientProvider( + new URI("kms", origUrl.getProtocol(), hosts[i], port, + origUrl.getPath(), null, null), conf); + } catch (URISyntaxException e) { + throw new IOException("Could not instantiate KMSProvider.", e); } - return new LoadBalancingKMSClientProvider(providers, conf); } + return new LoadBalancingKMSClientProvider(providers, conf); } } @@ -1031,7 +1026,11 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (Exception e) { - throw new IOException(e); + if (e instanceof IOException) { + throw (IOException) e; + } else { + throw new IOException(e); + } } } return tokens; diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java index 71d32ff5527..42cd47dd7a5 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java @@ -19,6 +19,7 @@ package org.apache.hadoop.crypto.key.kms; import java.io.IOException; +import java.io.InterruptedIOException; import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; import java.util.Arrays; @@ -31,9 +32,13 @@ import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion; import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension; +import org.apache.hadoop.fs.CommonConfigurationKeysPublic; +import org.apache.hadoop.io.retry.RetryPolicies; +import org.apache.hadoop.io.retry.RetryPolicy; +import org.apache.hadoop.io.retry.RetryPolicy.RetryAction; +import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.token.Token; -import org.apache.hadoop.util.StringUtils; import org.apache.hadoop.util.Time; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,6 +74,8 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements private final KMSClientProvider[] providers; private final AtomicInteger currentIdx; + private RetryPolicy retryPolicy = null; + public LoadBalancingKMSClientProvider(KMSClientProvider[] providers, Configuration conf) { this(shuffle(providers), Time.monotonicNow(), conf); @@ -80,24 +87,82 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements super(conf); this.providers = providers; this.currentIdx = new AtomicInteger((int)(seed % providers.length)); + int maxNumRetries = conf.getInt(CommonConfigurationKeysPublic. + KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, providers.length); + int sleepBaseMillis = conf.getInt(CommonConfigurationKeysPublic. + KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_KEY, + CommonConfigurationKeysPublic. + KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_DEFAULT); + int sleepMaxMillis = conf.getInt(CommonConfigurationKeysPublic. + KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_KEY, + CommonConfigurationKeysPublic. + KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_DEFAULT); + Preconditions.checkState(maxNumRetries >= 0); + Preconditions.checkState(sleepBaseMillis >= 0); + Preconditions.checkState(sleepMaxMillis >= 0); + this.retryPolicy = RetryPolicies.failoverOnNetworkException( + RetryPolicies.TRY_ONCE_THEN_FAIL, maxNumRetries, 0, sleepBaseMillis, + sleepMaxMillis); } @VisibleForTesting - KMSClientProvider[] getProviders() { + public KMSClientProvider[] getProviders() { return providers; } private T doOp(ProviderCallable op, int currPos) throws IOException { + if (providers.length == 0) { + throw new IOException("No providers configured !"); + } IOException ex = null; - for (int i = 0; i < providers.length; i++) { + int numFailovers = 0; + for (int i = 0;; i++, numFailovers++) { KMSClientProvider provider = providers[(currPos + i) % providers.length]; try { return op.call(provider); + } catch (AccessControlException ace) { + // No need to retry on AccessControlException + // and AuthorizationException. + // This assumes all the servers are configured with identical + // permissions and identical key acls. + throw ace; } catch (IOException ioe) { - LOG.warn("KMS provider at [{}] threw an IOException!! {}", - provider.getKMSUrl(), StringUtils.stringifyException(ioe)); + LOG.warn("KMS provider at [{}] threw an IOException: ", + provider.getKMSUrl(), ioe); ex = ioe; + + RetryAction action = null; + try { + action = retryPolicy.shouldRetry(ioe, 0, numFailovers, false); + } catch (Exception e) { + if (e instanceof IOException) { + throw (IOException)e; + } + throw new IOException(e); + } + // make sure each provider is tried at least once, to keep behavior + // compatible with earlier versions of LBKMSCP + if (action.action == RetryAction.RetryDecision.FAIL + && numFailovers >= providers.length - 1) { + LOG.warn("Aborting since the Request has failed with all KMS" + + " providers(depending on {}={} setting and numProviders={})" + + " in the group OR the exception is not recoverable", + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, + getConf().getInt( + CommonConfigurationKeysPublic. + KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, providers.length), + providers.length); + throw ex; + } + if (((numFailovers + 1) % providers.length) == 0) { + // Sleep only after we try all the providers for every cycle. + try { + Thread.sleep(action.delayMillis); + } catch (InterruptedException e) { + throw new InterruptedIOException("Thread Interrupted"); + } + } } catch (Exception e) { if (e instanceof RuntimeException) { throw (RuntimeException)e; @@ -106,12 +171,6 @@ public class LoadBalancingKMSClientProvider extends KeyProvider implements } } } - if (ex != null) { - LOG.warn("Aborting since the Request has failed with all KMS" - + " providers in the group. !!"); - throw ex; - } - throw new IOException("No providers configured !!"); } private int nextIdx() { diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/fs/CommonConfigurationKeysPublic.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/fs/CommonConfigurationKeysPublic.java index b5f355a8138..4fda2b83320 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/fs/CommonConfigurationKeysPublic.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/fs/CommonConfigurationKeysPublic.java @@ -721,6 +721,35 @@ public class CommonConfigurationKeysPublic { /** Default value for KMS_CLIENT_ENC_KEY_CACHE_EXPIRY (12 hrs)*/ public static final int KMS_CLIENT_ENC_KEY_CACHE_EXPIRY_DEFAULT = 43200000; + /** + * @see + * + * core-default.xml + */ + /** Default value is the number of providers specified. */ + public static final String KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY = + "hadoop.security.kms.client.failover.max.retries"; + + /** + * @see + * + * core-default.xml + */ + public static final String KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_KEY = + "hadoop.security.kms.client.failover.sleep.base.millis"; + /** Default value is 100 ms. */ + public static final int KMS_CLIENT_FAILOVER_SLEEP_BASE_MILLIS_DEFAULT = 100; + + /** + * @see + * + * core-default.xml + */ + public static final String KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_KEY = + "hadoop.security.kms.client.failover.sleep.max.millis"; + /** Default value is 2 secs. */ + public static final int KMS_CLIENT_FAILOVER_SLEEP_MAX_MILLIS_DEFAULT = 2000; + /** * @see * diff --git a/hadoop-common-project/hadoop-common/src/main/resources/core-default.xml b/hadoop-common-project/hadoop-common/src/main/resources/core-default.xml index d481048710e..291608ca426 100644 --- a/hadoop-common-project/hadoop-common/src/main/resources/core-default.xml +++ b/hadoop-common-project/hadoop-common/src/main/resources/core-default.xml @@ -2335,6 +2335,34 @@ + + hadoop.security.kms.client.failover.sleep.base.millis + 100 + + Expert only. The time to wait, in milliseconds, between failover + attempts increases exponentially as a function of the number of + attempts made so far, with a random factor of +/- 50%. This option + specifies the base value used in the failover calculation. The + first failover will retry immediately. The 2nd failover attempt + will delay at least hadoop.security.client.failover.sleep.base.millis + milliseconds. And so on. + + + + + hadoop.security.kms.client.failover.sleep.max.millis + 2000 + + Expert only. The time to wait, in milliseconds, between failover + attempts increases exponentially as a function of the number of + attempts made so far, with a random factor of +/- 50%. This option + specifies the maximum value to wait between failovers. + Specifically, the time between two failover attempts will not + exceed +/- 50% of hadoop.security.client.failover.sleep.max.millis + milliseconds. + + + ipc.server.max.connections 0 diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java index d14dd59c773..7a70e18120b 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java @@ -23,9 +23,12 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; import java.io.IOException; +import java.net.NoRouteToHostException; import java.net.URI; +import java.net.UnknownHostException; import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; @@ -33,6 +36,9 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProvider.Options; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension; +import org.apache.hadoop.fs.CommonConfigurationKeysPublic; +import org.apache.hadoop.net.ConnectTimeoutException; +import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.security.authentication.client.AuthenticationException; import org.apache.hadoop.security.authorize.AuthorizationException; import org.junit.Test; @@ -47,14 +53,17 @@ public class TestLoadBalancingKMSClientProvider { Configuration conf = new Configuration(); KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI( "kms://http@host1/kms/foo"), conf); - assertTrue(kp instanceof KMSClientProvider); - assertEquals("http://host1/kms/foo/v1/", - ((KMSClientProvider) kp).getKMSUrl()); + assertTrue(kp instanceof LoadBalancingKMSClientProvider); + KMSClientProvider[] providers = + ((LoadBalancingKMSClientProvider) kp).getProviders(); + assertEquals(1, providers.length); + assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/"), + Sets.newHashSet(providers[0].getKMSUrl())); kp = new KMSClientProvider.Factory().createProvider(new URI( "kms://http@host1;host2;host3/kms/foo"), conf); assertTrue(kp instanceof LoadBalancingKMSClientProvider); - KMSClientProvider[] providers = + providers = ((LoadBalancingKMSClientProvider) kp).getProviders(); assertEquals(3, providers.length); assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/", @@ -320,4 +329,298 @@ public class TestLoadBalancingKMSClientProvider { Mockito.verify(p1, Mockito.times(1)).warmUpEncryptedKeys(keyName); Mockito.verify(p2, Mockito.times(1)).warmUpEncryptedKeys(keyName); } -} + + /** + * Tests whether retryPolicy fails immediately, after trying each provider + * once, on encountering IOException which is not SocketException. + * @throws Exception + */ + @Test + public void testClientRetriesWithIOException() throws Exception { + Configuration conf = new Configuration(); + // Setting total failover attempts to . + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 10); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.getMetadata(Mockito.anyString())) + .thenThrow(new IOException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.getMetadata(Mockito.anyString())) + .thenThrow(new IOException("p2")); + KMSClientProvider p3 = mock(KMSClientProvider.class); + when(p3.getMetadata(Mockito.anyString())) + .thenThrow(new IOException("p3")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + when(p3.getKMSUrl()).thenReturn("p3"); + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2, p3}, 0, conf); + try { + kp.getMetadata("test3"); + fail("Should fail since all providers threw an IOException"); + } catch (Exception e) { + assertTrue(e instanceof IOException); + } + verify(kp.getProviders()[0], Mockito.times(1)) + .getMetadata(Mockito.eq("test3")); + verify(kp.getProviders()[1], Mockito.times(1)) + .getMetadata(Mockito.eq("test3")); + verify(kp.getProviders()[2], Mockito.times(1)) + .getMetadata(Mockito.eq("test3")); + } + + /** + * Tests that client doesn't retry once it encounters AccessControlException + * from first provider. + * This assumes all the kms servers are configured with identical access to + * keys. + * @throws Exception + */ + @Test + public void testClientRetriesWithAccessControlException() throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new AccessControlException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p2")); + KMSClientProvider p3 = mock(KMSClientProvider.class); + when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p3")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + when(p3.getKMSUrl()).thenReturn("p3"); + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2, p3}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + fail("Should fail because provider p1 threw an AccessControlException"); + } catch (Exception e) { + assertTrue(e instanceof AccessControlException); + } + verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.never()).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p3, Mockito.never()).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } + + /** + * Tests that client doesn't retry once it encounters RunTimeException + * from first provider. + * This assumes all the kms servers are configured with identical access to + * keys. + * @throws Exception + */ + @Test + public void testClientRetriesWithRuntimeException() throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new RuntimeException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p2")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + fail("Should fail since provider p1 threw RuntimeException"); + } catch (Exception e) { + assertTrue(e instanceof RuntimeException); + } + verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.never()).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } + + /** + * Tests the client retries until it finds a good provider. + * @throws Exception + */ + @Test + public void testClientRetriesWithTimeoutsException() throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 4); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new UnknownHostException("p2")); + KMSClientProvider p3 = mock(KMSClientProvider.class); + when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new NoRouteToHostException("p3")); + KMSClientProvider p4 = mock(KMSClientProvider.class); + when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenReturn( + new KMSClientProvider.KMSKeyVersion("test3", "v1", new byte[0])); + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + when(p3.getKMSUrl()).thenReturn("p3"); + when(p4.getKMSUrl()).thenReturn("p4"); + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2, p3, p4}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + } catch (Exception e) { + fail("Provider p4 should have answered the request."); + } + verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p3, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p4, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } + + /** + * Tests the operation succeeds second time after ConnectTimeoutException. + * @throws Exception + */ + @Test + public void testClientRetriesSucceedsSecondTime() throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p1")) + .thenReturn(new KMSClientProvider.KMSKeyVersion("test3", "v1", + new byte[0])); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p2")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + } catch (Exception e) { + fail("Provider p1 should have answered the request second time."); + } + verify(p1, Mockito.times(2)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } + + /** + * Tests whether retryPolicy retries specified number of times. + * @throws Exception + */ + @Test + public void testClientRetriesSpecifiedNumberOfTimes() throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 10); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p2")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + fail("Should fail"); + } catch (Exception e) { + assert (e instanceof ConnectTimeoutException); + } + verify(p1, Mockito.times(6)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.times(5)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } + + /** + * Tests whether retryPolicy retries number of times equals to number of + * providers if conf kms.client.failover.max.attempts is not set. + * @throws Exception + */ + @Test + public void testClientRetriesIfMaxAttemptsNotSet() throws Exception { + Configuration conf = new Configuration(); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new ConnectTimeoutException("p2")); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + fail("Should fail"); + } catch (Exception e) { + assert (e instanceof ConnectTimeoutException); + } + verify(p1, Mockito.times(2)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } + + /** + * Tests that client reties each provider once, when it encounters + * AuthenticationException wrapped in an IOException from first provider. + * @throws Exception + */ + @Test + public void testClientRetriesWithAuthenticationExceptionWrappedinIOException() + throws Exception { + Configuration conf = new Configuration(); + conf.setInt( + CommonConfigurationKeysPublic.KMS_CLIENT_FAILOVER_MAX_RETRIES_KEY, 3); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException(new AuthenticationException("p1"))); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException(new AuthenticationException("p1"))); + + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + + LoadBalancingKMSClientProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] {p1, p2}, 0, conf); + try { + kp.createKey("test3", new Options(conf)); + fail("Should fail since provider p1 threw AuthenticationException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof AuthenticationException); + } + verify(p1, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + verify(p2, Mockito.times(1)).createKey(Mockito.eq("test3"), + Mockito.any(Options.class)); + } +} \ No newline at end of file diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/TestEncryptionZonesWithKMS.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/TestEncryptionZonesWithKMS.java index 959e724b58f..6f533625cc7 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/TestEncryptionZonesWithKMS.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/TestEncryptionZonesWithKMS.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertTrue; import com.google.common.base.Supplier; import org.apache.hadoop.crypto.key.kms.KMSClientProvider; +import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider; import org.apache.hadoop.crypto.key.kms.server.MiniKMS; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.UserGroupInformation; @@ -69,14 +70,21 @@ public class TestEncryptionZonesWithKMS extends TestEncryptionZones { protected void setProvider() { } + private KMSClientProvider getKMSClientProvider() { + LoadBalancingKMSClientProvider lbkmscp = + (LoadBalancingKMSClientProvider) Whitebox + .getInternalState(cluster.getNamesystem().getProvider(), "extension"); + assert lbkmscp.getProviders().length == 1; + return lbkmscp.getProviders()[0]; + } + @Test(timeout = 120000) public void testCreateEZPopulatesEDEKCache() throws Exception { final Path zonePath = new Path("/TestEncryptionZone"); fsWrapper.mkdir(zonePath, FsPermission.getDirDefault(), false); dfsAdmin.createEncryptionZone(zonePath, TEST_KEY, NO_TRASH); @SuppressWarnings("unchecked") - KMSClientProvider kcp = (KMSClientProvider) Whitebox - .getInternalState(cluster.getNamesystem().getProvider(), "extension"); + KMSClientProvider kcp = getKMSClientProvider(); assertTrue(kcp.getEncKeyQueueSize(TEST_KEY) > 0); } @@ -110,8 +118,7 @@ public class TestEncryptionZonesWithKMS extends TestEncryptionZones { dfsAdmin.createEncryptionZone(zonePath, anotherKey, NO_TRASH); @SuppressWarnings("unchecked") - KMSClientProvider spy = (KMSClientProvider) Whitebox - .getInternalState(cluster.getNamesystem().getProvider(), "extension"); + KMSClientProvider spy = getKMSClientProvider(); assertTrue("key queue is empty after creating encryption zone", spy.getEncKeyQueueSize(TEST_KEY) > 0); @@ -122,9 +129,7 @@ public class TestEncryptionZonesWithKMS extends TestEncryptionZones { GenericTestUtils.waitFor(new Supplier() { @Override public Boolean get() { - final KMSClientProvider kspy = (KMSClientProvider) Whitebox - .getInternalState(cluster.getNamesystem().getProvider(), - "extension"); + final KMSClientProvider kspy = getKMSClientProvider(); return kspy.getEncKeyQueueSize(TEST_KEY) > 0; } }, 1000, 60000);