From 0512e50d6ed93d8192f7c9721798cc593d10b80a Mon Sep 17 00:00:00 2001 From: Andrew Wang Date: Wed, 25 Feb 2015 21:15:44 -0800 Subject: [PATCH] HADOOP-11620. Add support for load balancing across a group of KMS for HA. Contributed by Arun Suresh. (cherry picked from commit 71385f9b70e22618db3f3d2b2c6dca3b1e82c317) --- .../hadoop-common/CHANGES.txt | 3 + .../crypto/key/kms/KMSClientProvider.java | 84 ++++- .../kms/LoadBalancingKMSClientProvider.java | 347 ++++++++++++++++++ .../TestLoadBalancingKMSClientProvider.java | 166 +++++++++ .../hadoop/crypto/key/kms/server/TestKMS.java | 114 +++--- 5 files changed, 654 insertions(+), 60 deletions(-) create mode 100644 hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java create mode 100644 hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java diff --git a/hadoop-common-project/hadoop-common/CHANGES.txt b/hadoop-common-project/hadoop-common/CHANGES.txt index 9308c3515e5..34eeb45446c 100644 --- a/hadoop-common-project/hadoop-common/CHANGES.txt +++ b/hadoop-common-project/hadoop-common/CHANGES.txt @@ -237,6 +237,9 @@ Release 2.7.0 - UNRELEASED HADOOP-11231. Remove dead code in ServletUtil. (Li Lu via wheat9) + HADOOP-11620. Add support for load balancing across a group of KMS for HA. + (Arun Suresh via wang) + BUG FIXES HADOOP-11512. Use getTrimmedStrings when reading serialization keys diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java index a2cea409a3d..149424fb87d 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/KMSClientProvider.java @@ -52,6 +52,7 @@ import java.io.Writer; import java.lang.reflect.UndeclaredThrowableException; import java.net.HttpURLConnection; import java.net.InetSocketAddress; +import java.net.MalformedURLException; import java.net.SocketTimeoutException; import java.net.URI; import java.net.URISyntaxException; @@ -74,6 +75,7 @@ import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; /** * KMS client KeyProvider implementation. @@ -221,14 +223,71 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, */ public static class Factory extends KeyProviderFactory { + /** + * This provider expects URIs in the following form : + * kms://@/ + * + * where : + * - PROTO = http or https + * - AUTHORITY = [:] + * - HOSTS = [;] + * - HOSTNAME = string + * - PORT = integer + * + * If multiple hosts are provider, the Factory will create a + * {@link LoadBalancingKMSClientProvider} that round-robins requests + * across the provided list of hosts. + */ @Override - public KeyProvider createProvider(URI providerName, Configuration conf) + public KeyProvider createProvider(URI providerUri, Configuration conf) throws IOException { - if (SCHEME_NAME.equals(providerName.getScheme())) { - return new KMSClientProvider(providerName, conf); + if (SCHEME_NAME.equals(providerUri.getScheme())) { + URL origUrl = new URL(extractKMSPath(providerUri).toString()); + String authority = origUrl.getAuthority(); + // check for ';' which delimits the backup hosts + if (Strings.isNullOrEmpty(authority)) { + throw new IOException( + "No valid authority in kms uri [" + origUrl + "]"); + } + // Check if port is present in authority + // In the current scheme, all hosts have to run on the same port + int port = -1; + String hostsPart = authority; + if (authority.contains(":")) { + String[] t = authority.split(":"); + try { + port = Integer.parseInt(t[1]); + } catch (Exception e) { + throw new IOException( + "Could not parse port in kms uri [" + origUrl + "]"); + } + hostsPart = t[0]; + } + return createProvider(providerUri, conf, origUrl, port, hostsPart); } return null; } + + private KeyProvider createProvider(URI providerUri, Configuration conf, + URL origUrl, int port, String hostsPart) throws IOException { + String[] hosts = hostsPart.split(";"); + if (hosts.length == 1) { + return new KMSClientProvider(providerUri, conf); + } else { + KMSClientProvider[] providers = new KMSClientProvider[hosts.length]; + for (int i = 0; i < hosts.length; i++) { + try { + providers[i] = + new KMSClientProvider( + new URI("kms", origUrl.getProtocol(), hosts[i], port, + origUrl.getPath(), null, null), conf); + } catch (URISyntaxException e) { + throw new IOException("Could not instantiate KMSProvider..", e); + } + } + return new LoadBalancingKMSClientProvider(providers, conf); + } + } } public static T checkNotNull(T o, String name) @@ -302,10 +361,8 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, public KMSClientProvider(URI uri, Configuration conf) throws IOException { super(conf); - Path path = ProviderUtils.unnestUri(uri); - URL url = path.toUri().toURL(); - kmsUrl = createServiceURL(url); - if ("https".equalsIgnoreCase(url.getProtocol())) { + kmsUrl = createServiceURL(extractKMSPath(uri)); + if ("https".equalsIgnoreCase(new URL(kmsUrl).getProtocol())) { sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, conf); try { sslFactory.init(); @@ -346,8 +403,12 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, .getCurrentUser(); } - private String createServiceURL(URL url) throws IOException { - String str = url.toExternalForm(); + private static Path extractKMSPath(URI uri) throws MalformedURLException, IOException { + return ProviderUtils.unnestUri(uri); + } + + private static String createServiceURL(Path path) throws IOException { + String str = new URL(path.toString()).toExternalForm(); if (str.endsWith("/")) { str = str.substring(0, str.length() - 1); } @@ -853,4 +914,9 @@ public class KMSClientProvider extends KeyProvider implements CryptoExtension, } } } + + @VisibleForTesting + String getKMSUrl() { + return kmsUrl; + } } diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java new file mode 100644 index 00000000000..c1579e71326 --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/kms/LoadBalancingKMSClientProvider.java @@ -0,0 +1,347 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.crypto.key.kms; + +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.crypto.key.KeyProvider; +import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension; +import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion; +import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.util.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.annotations.VisibleForTesting; + +/** + * A simple LoadBalancing KMSClientProvider that round-robins requests + * across a provided array of KMSClientProviders. It also retries failed + * requests on the next available provider in the load balancer group. It + * only retries failed requests that result in an IOException, sending back + * all other Exceptions to the caller without retry. + */ +public class LoadBalancingKMSClientProvider extends KeyProvider implements + CryptoExtension, + KeyProviderDelegationTokenExtension.DelegationTokenExtension { + + public static Logger LOG = + LoggerFactory.getLogger(LoadBalancingKMSClientProvider.class); + + static interface ProviderCallable { + public T call(KMSClientProvider provider) throws IOException, Exception; + } + + @SuppressWarnings("serial") + static class WrapperException extends RuntimeException { + public WrapperException(Throwable cause) { + super(cause); + } + } + + private final KMSClientProvider[] providers; + private final AtomicInteger currentIdx; + + public LoadBalancingKMSClientProvider(KMSClientProvider[] providers, + Configuration conf) { + this(shuffle(providers), Time.monotonicNow(), conf); + } + + @VisibleForTesting + LoadBalancingKMSClientProvider(KMSClientProvider[] providers, long seed, + Configuration conf) { + super(conf); + this.providers = providers; + this.currentIdx = new AtomicInteger((int)(seed % providers.length)); + } + + @VisibleForTesting + KMSClientProvider[] getProviders() { + return providers; + } + + private T doOp(ProviderCallable op, int currPos) + throws IOException { + IOException ex = null; + for (int i = 0; i < providers.length; i++) { + KMSClientProvider provider = providers[(currPos + i) % providers.length]; + try { + return op.call(provider); + } catch (IOException ioe) { + LOG.warn("KMS provider at [{}] threw an IOException [{}]!!", + provider.getKMSUrl(), ioe.getMessage()); + ex = ioe; + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException)e; + } else { + throw new WrapperException(e); + } + } + } + if (ex != null) { + LOG.warn("Aborting since the Request has failed with all KMS" + + " providers in the group. !!"); + throw ex; + } + throw new IOException("No providers configured !!"); + } + + private int nextIdx() { + while (true) { + int current = currentIdx.get(); + int next = (current + 1) % providers.length; + if (currentIdx.compareAndSet(current, next)) { + return current; + } + } + } + + @Override + public Token[] + addDelegationTokens(final String renewer, final Credentials credentials) + throws IOException { + return doOp(new ProviderCallable[]>() { + @Override + public Token[] call(KMSClientProvider provider) throws IOException { + return provider.addDelegationTokens(renewer, credentials); + } + }, nextIdx()); + } + + // This request is sent to all providers in the load-balancing group + @Override + public void warmUpEncryptedKeys(String... keyNames) throws IOException { + for (KMSClientProvider provider : providers) { + try { + provider.warmUpEncryptedKeys(keyNames); + } catch (IOException ioe) { + LOG.error( + "Error warming up keys for provider with url" + + "[" + provider.getKMSUrl() + "]"); + } + } + } + + // This request is sent to all providers in the load-balancing group + @Override + public void drain(String keyName) { + for (KMSClientProvider provider : providers) { + provider.drain(keyName); + } + } + + @Override + public EncryptedKeyVersion + generateEncryptedKey(final String encryptionKeyName) + throws IOException, GeneralSecurityException { + try { + return doOp(new ProviderCallable() { + @Override + public EncryptedKeyVersion call(KMSClientProvider provider) + throws IOException, GeneralSecurityException { + return provider.generateEncryptedKey(encryptionKeyName); + } + }, nextIdx()); + } catch (WrapperException we) { + throw (GeneralSecurityException) we.getCause(); + } + } + + @Override + public KeyVersion + decryptEncryptedKey(final EncryptedKeyVersion encryptedKeyVersion) + throws IOException, GeneralSecurityException { + try { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) + throws IOException, GeneralSecurityException { + return provider.decryptEncryptedKey(encryptedKeyVersion); + } + }, nextIdx()); + } catch (WrapperException we) { + throw (GeneralSecurityException)we.getCause(); + } + } + + @Override + public KeyVersion getKeyVersion(final String versionName) throws IOException { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) throws IOException { + return provider.getKeyVersion(versionName); + } + }, nextIdx()); + } + + @Override + public List getKeys() throws IOException { + return doOp(new ProviderCallable>() { + @Override + public List call(KMSClientProvider provider) throws IOException { + return provider.getKeys(); + } + }, nextIdx()); + } + + @Override + public Metadata[] getKeysMetadata(final String... names) throws IOException { + return doOp(new ProviderCallable() { + @Override + public Metadata[] call(KMSClientProvider provider) throws IOException { + return provider.getKeysMetadata(names); + } + }, nextIdx()); + } + + @Override + public List getKeyVersions(final String name) throws IOException { + return doOp(new ProviderCallable>() { + @Override + public List call(KMSClientProvider provider) + throws IOException { + return provider.getKeyVersions(name); + } + }, nextIdx()); + } + + @Override + public KeyVersion getCurrentKey(final String name) throws IOException { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) throws IOException { + return provider.getCurrentKey(name); + } + }, nextIdx()); + } + @Override + public Metadata getMetadata(final String name) throws IOException { + return doOp(new ProviderCallable() { + @Override + public Metadata call(KMSClientProvider provider) throws IOException { + return provider.getMetadata(name); + } + }, nextIdx()); + } + + @Override + public KeyVersion createKey(final String name, final byte[] material, + final Options options) throws IOException { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) throws IOException { + return provider.createKey(name, material, options); + } + }, nextIdx()); + } + + @Override + public KeyVersion createKey(final String name, final Options options) + throws NoSuchAlgorithmException, IOException { + try { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) throws IOException, + NoSuchAlgorithmException { + return provider.createKey(name, options); + } + }, nextIdx()); + } catch (WrapperException e) { + throw (NoSuchAlgorithmException)e.getCause(); + } + } + @Override + public void deleteKey(final String name) throws IOException { + doOp(new ProviderCallable() { + @Override + public Void call(KMSClientProvider provider) throws IOException { + provider.deleteKey(name); + return null; + } + }, nextIdx()); + } + @Override + public KeyVersion rollNewVersion(final String name, final byte[] material) + throws IOException { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) throws IOException { + return provider.rollNewVersion(name, material); + } + }, nextIdx()); + } + + @Override + public KeyVersion rollNewVersion(final String name) + throws NoSuchAlgorithmException, IOException { + try { + return doOp(new ProviderCallable() { + @Override + public KeyVersion call(KMSClientProvider provider) throws IOException, + NoSuchAlgorithmException { + return provider.rollNewVersion(name); + } + }, nextIdx()); + } catch (WrapperException e) { + throw (NoSuchAlgorithmException)e.getCause(); + } + } + + // Close all providers in the LB group + @Override + public void close() throws IOException { + for (KMSClientProvider provider : providers) { + try { + provider.close(); + } catch (IOException ioe) { + LOG.error("Error closing provider with url" + + "[" + provider.getKMSUrl() + "]"); + } + } + } + + + @Override + public void flush() throws IOException { + for (KMSClientProvider provider : providers) { + try { + provider.flush(); + } catch (IOException ioe) { + LOG.error("Error flushing provider with url" + + "[" + provider.getKMSUrl() + "]"); + } + } + } + + private static KMSClientProvider[] shuffle(KMSClientProvider[] providers) { + List list = Arrays.asList(providers); + Collections.shuffle(list); + return list.toArray(providers); + } +} diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java new file mode 100644 index 00000000000..08a3d93d2fa --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/kms/TestLoadBalancingKMSClientProvider.java @@ -0,0 +1,166 @@ +/** when(p1.getKMSUrl()).thenReturn("p1"); + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.crypto.key.kms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.URI; +import java.security.NoSuchAlgorithmException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.crypto.key.KeyProvider; +import org.apache.hadoop.crypto.key.KeyProvider.Options; +import org.junit.Test; +import org.mockito.Mockito; + +import com.google.common.collect.Sets; + +public class TestLoadBalancingKMSClientProvider { + + @Test + public void testCreation() throws Exception { + Configuration conf = new Configuration(); + KeyProvider kp = new KMSClientProvider.Factory().createProvider(new URI( + "kms://http@host1/kms/foo"), conf); + assertTrue(kp instanceof KMSClientProvider); + assertEquals("http://host1/kms/foo/v1/", + ((KMSClientProvider) kp).getKMSUrl()); + + kp = new KMSClientProvider.Factory().createProvider(new URI( + "kms://http@host1;host2;host3/kms/foo"), conf); + assertTrue(kp instanceof LoadBalancingKMSClientProvider); + KMSClientProvider[] providers = + ((LoadBalancingKMSClientProvider) kp).getProviders(); + assertEquals(3, providers.length); + assertEquals(Sets.newHashSet("http://host1/kms/foo/v1/", + "http://host2/kms/foo/v1/", + "http://host3/kms/foo/v1/"), + Sets.newHashSet(providers[0].getKMSUrl(), + providers[1].getKMSUrl(), + providers[2].getKMSUrl())); + + kp = new KMSClientProvider.Factory().createProvider(new URI( + "kms://http@host1;host2;host3:16000/kms/foo"), conf); + assertTrue(kp instanceof LoadBalancingKMSClientProvider); + providers = + ((LoadBalancingKMSClientProvider) kp).getProviders(); + assertEquals(3, providers.length); + assertEquals(Sets.newHashSet("http://host1:16000/kms/foo/v1/", + "http://host2:16000/kms/foo/v1/", + "http://host3:16000/kms/foo/v1/"), + Sets.newHashSet(providers[0].getKMSUrl(), + providers[1].getKMSUrl(), + providers[2].getKMSUrl())); + } + + @Test + public void testLoadBalancing() throws Exception { + Configuration conf = new Configuration(); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenReturn( + new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0])); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenReturn( + new KMSClientProvider.KMSKeyVersion("p2", "v2", new byte[0])); + KMSClientProvider p3 = mock(KMSClientProvider.class); + when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenReturn( + new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0])); + KeyProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] { p1, p2, p3 }, 0, conf); + assertEquals("p1", kp.createKey("test1", new Options(conf)).getName()); + assertEquals("p2", kp.createKey("test2", new Options(conf)).getName()); + assertEquals("p3", kp.createKey("test3", new Options(conf)).getName()); + assertEquals("p1", kp.createKey("test4", new Options(conf)).getName()); + } + + @Test + public void testLoadBalancingWithFailure() throws Exception { + Configuration conf = new Configuration(); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenReturn( + new KMSClientProvider.KMSKeyVersion("p1", "v1", new byte[0])); + when(p1.getKMSUrl()).thenReturn("p1"); + // This should not be retried + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new NoSuchAlgorithmException("p2")); + when(p2.getKMSUrl()).thenReturn("p2"); + KMSClientProvider p3 = mock(KMSClientProvider.class); + when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenReturn( + new KMSClientProvider.KMSKeyVersion("p3", "v3", new byte[0])); + when(p3.getKMSUrl()).thenReturn("p3"); + // This should be retried + KMSClientProvider p4 = mock(KMSClientProvider.class); + when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p4")); + when(p4.getKMSUrl()).thenReturn("p4"); + KeyProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf); + + assertEquals("p1", kp.createKey("test4", new Options(conf)).getName()); + // Exceptions other than IOExceptions will not be retried + try { + kp.createKey("test1", new Options(conf)).getName(); + fail("Should fail since its not an IOException"); + } catch (Exception e) { + assertTrue(e instanceof NoSuchAlgorithmException); + } + assertEquals("p3", kp.createKey("test2", new Options(conf)).getName()); + // IOException will trigger retry in next provider + assertEquals("p1", kp.createKey("test3", new Options(conf)).getName()); + } + + @Test + public void testLoadBalancingWithAllBadNodes() throws Exception { + Configuration conf = new Configuration(); + KMSClientProvider p1 = mock(KMSClientProvider.class); + when(p1.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p1")); + KMSClientProvider p2 = mock(KMSClientProvider.class); + when(p2.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p2")); + KMSClientProvider p3 = mock(KMSClientProvider.class); + when(p3.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p3")); + KMSClientProvider p4 = mock(KMSClientProvider.class); + when(p4.createKey(Mockito.anyString(), Mockito.any(Options.class))) + .thenThrow(new IOException("p4")); + when(p1.getKMSUrl()).thenReturn("p1"); + when(p2.getKMSUrl()).thenReturn("p2"); + when(p3.getKMSUrl()).thenReturn("p3"); + when(p4.getKMSUrl()).thenReturn("p4"); + KeyProvider kp = new LoadBalancingKMSClientProvider( + new KMSClientProvider[] { p1, p2, p3, p4 }, 0, conf); + try { + kp.createKey("test3", new Options(conf)).getName(); + fail("Should fail since all providers threw an IOException"); + } catch (Exception e) { + assertTrue(e instanceof IOException); + } + } +} diff --git a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java index 70ba95f28d7..c5a990b58b7 100644 --- a/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java +++ b/hadoop-common-project/hadoop-kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMS.java @@ -24,9 +24,11 @@ import org.apache.hadoop.crypto.key.KeyProvider; import org.apache.hadoop.crypto.key.KeyProvider.KeyVersion; import org.apache.hadoop.crypto.key.KeyProvider.Options; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension; +import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension; import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion; import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension; import org.apache.hadoop.crypto.key.kms.KMSClientProvider; +import org.apache.hadoop.crypto.key.kms.LoadBalancingKMSClientProvider; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Text; import org.apache.hadoop.minikdc.MiniKdc; @@ -99,6 +101,12 @@ public class TestKMS { } } + protected KeyProvider createProvider(URI uri, Configuration conf) + throws IOException { + return new LoadBalancingKMSClientProvider( + new KMSClientProvider[] { new KMSClientProvider(uri, conf) }, conf); + } + protected T runServer(String keystore, String password, File confDir, KMSCallable callable) throws Exception { return runServer(-1, keystore, password, confDir, callable); @@ -305,7 +313,7 @@ public class TestKMS { final URI uri = createKMSUri(getKMSUrl()); if (ssl) { - KeyProvider testKp = new KMSClientProvider(uri, conf); + KeyProvider testKp = createProvider(uri, conf); ThreadGroup threadGroup = Thread.currentThread().getThreadGroup(); while (threadGroup.getParent() != null) { threadGroup = threadGroup.getParent(); @@ -335,12 +343,14 @@ public class TestKMS { doAs(user, new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - final KeyProvider kp = new KMSClientProvider(uri, conf); + final KeyProvider kp = createProvider(uri, conf); // getKeys() empty Assert.assertTrue(kp.getKeys().isEmpty()); Thread.sleep(4000); - Token[] tokens = ((KMSClientProvider)kp).addDelegationTokens("myuser", new Credentials()); + Token[] tokens = + ((KeyProviderDelegationTokenExtension.DelegationTokenExtension)kp) + .addDelegationTokens("myuser", new Credentials()); Assert.assertEquals(1, tokens.length); Assert.assertEquals("kms-dt", tokens[0].getKind().toString()); return null; @@ -348,12 +358,14 @@ public class TestKMS { }); } } else { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); // getKeys() empty Assert.assertTrue(kp.getKeys().isEmpty()); Thread.sleep(4000); - Token[] tokens = ((KMSClientProvider)kp).addDelegationTokens("myuser", new Credentials()); + Token[] tokens = + ((KeyProviderDelegationTokenExtension.DelegationTokenExtension)kp) + .addDelegationTokens("myuser", new Credentials()); Assert.assertEquals(1, tokens.length); Assert.assertEquals("kms-dt", tokens[0].getKind().toString()); } @@ -404,7 +416,7 @@ public class TestKMS { Date started = new Date(); Configuration conf = new Configuration(); URI uri = createKMSUri(getKMSUrl()); - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); // getKeys() empty Assert.assertTrue(kp.getKeys().isEmpty()); @@ -687,7 +699,7 @@ public class TestKMS { doAs("CREATE", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { Options options = new KeyProvider.Options(conf); Map attributes = options.getAttributes(); @@ -727,7 +739,7 @@ public class TestKMS { doAs("DECRYPT_EEK", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { Options options = new KeyProvider.Options(conf); Map attributes = options.getAttributes(); @@ -760,7 +772,7 @@ public class TestKMS { doAs("ROLLOVER", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { Options options = new KeyProvider.Options(conf); Map attributes = options.getAttributes(); @@ -804,7 +816,7 @@ public class TestKMS { doAs("GET", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { Options options = new KeyProvider.Options(conf); Map attributes = options.getAttributes(); @@ -836,7 +848,7 @@ public class TestKMS { final EncryptedKeyVersion ekv = doAs("GENERATE_EEK", new PrivilegedExceptionAction() { @Override public EncryptedKeyVersion run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { Options options = new KeyProvider.Options(conf); Map attributes = options.getAttributes(); @@ -861,7 +873,7 @@ public class TestKMS { doAs("ROLLOVER", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProviderCryptoExtension kpce = KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp); @@ -891,7 +903,7 @@ public class TestKMS { doAs("GENERATE_EEK", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProviderCryptoExtension kpce = KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp); @@ -964,7 +976,7 @@ public class TestKMS { new PrivilegedExceptionAction() { @Override public KeyProvider run() throws Exception { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey("k1", new byte[16], new KeyProvider.Options(conf)); return kp; @@ -1041,7 +1053,7 @@ public class TestKMS { new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey("k0", new byte[16], new KeyProvider.Options(conf)); @@ -1072,7 +1084,7 @@ public class TestKMS { new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey("k3", new byte[16], new KeyProvider.Options(conf)); // Atleast 2 rollovers.. so should induce signer Exception @@ -1132,7 +1144,7 @@ public class TestKMS { doAs("client", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { kp.createKey("k", new KeyProvider.Options(conf)); Assert.fail(); @@ -1223,7 +1235,7 @@ public class TestKMS { doAs("CREATE", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProvider.KeyVersion kv = kp.createKey("k0", new KeyProvider.Options(conf)); @@ -1238,7 +1250,7 @@ public class TestKMS { doAs("DELETE", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { kp.deleteKey("k0"); } catch (Exception ex) { @@ -1251,7 +1263,7 @@ public class TestKMS { doAs("SET_KEY_MATERIAL", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProvider.KeyVersion kv = kp.createKey("k1", new byte[16], new KeyProvider.Options(conf)); @@ -1266,7 +1278,7 @@ public class TestKMS { doAs("ROLLOVER", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProvider.KeyVersion kv = kp.rollNewVersion("k1"); Assert.assertNull(kv.getMaterial()); @@ -1280,7 +1292,7 @@ public class TestKMS { doAs("SET_KEY_MATERIAL", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProvider.KeyVersion kv = kp.rollNewVersion("k1", new byte[16]); @@ -1296,7 +1308,7 @@ public class TestKMS { doAs("GET", new PrivilegedExceptionAction() { @Override public KeyVersion run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { kp.getKeyVersion("k1@0"); KeyVersion kv = kp.getCurrentKey("k1"); @@ -1313,7 +1325,7 @@ public class TestKMS { new PrivilegedExceptionAction() { @Override public EncryptedKeyVersion run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension. createKeyProviderCryptoExtension(kp); @@ -1330,7 +1342,7 @@ public class TestKMS { doAs("DECRYPT_EEK", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { KeyProviderCryptoExtension kpCE = KeyProviderCryptoExtension. createKeyProviderCryptoExtension(kp); @@ -1345,7 +1357,7 @@ public class TestKMS { doAs("GET_KEYS", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { kp.getKeys(); } catch (Exception ex) { @@ -1358,7 +1370,7 @@ public class TestKMS { doAs("GET_METADATA", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); try { kp.getMetadata("k1"); kp.getKeysMetadata("k1"); @@ -1385,7 +1397,7 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProvider.KeyVersion kv = kp.createKey("k2", new KeyProvider.Options(conf)); Assert.fail(); @@ -1440,12 +1452,12 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProvider.KeyVersion kv = kp.createKey("ck0", new KeyProvider.Options(conf)); EncryptedKeyVersion eek = - kp.generateEncryptedKey("ck0"); - kp.decryptEncryptedKey(eek); + ((CryptoExtension)kp).generateEncryptedKey("ck0"); + ((CryptoExtension)kp).decryptEncryptedKey(eek); Assert.assertNull(kv.getMaterial()); } catch (Exception ex) { Assert.fail(ex.getMessage()); @@ -1458,12 +1470,12 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProvider.KeyVersion kv = kp.createKey("ck1", new KeyProvider.Options(conf)); EncryptedKeyVersion eek = - kp.generateEncryptedKey("ck1"); - kp.decryptEncryptedKey(eek); + ((CryptoExtension)kp).generateEncryptedKey("ck1"); + ((CryptoExtension)kp).decryptEncryptedKey(eek); Assert.fail("admin user must not be allowed to decrypt !!"); } catch (Exception ex) { } @@ -1475,12 +1487,12 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProvider.KeyVersion kv = kp.createKey("ck2", new KeyProvider.Options(conf)); EncryptedKeyVersion eek = - kp.generateEncryptedKey("ck2"); - kp.decryptEncryptedKey(eek); + ((CryptoExtension)kp).generateEncryptedKey("ck2"); + ((CryptoExtension)kp).decryptEncryptedKey(eek); Assert.fail("admin user must not be allowed to decrypt !!"); } catch (Exception ex) { } @@ -1525,7 +1537,7 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProvider.KeyVersion kv = kp.createKey("ck0", new KeyProvider.Options(conf)); Assert.assertNull(kv.getMaterial()); @@ -1540,7 +1552,7 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProvider.KeyVersion kv = kp.createKey("ck1", new KeyProvider.Options(conf)); Assert.assertNull(kv.getMaterial()); @@ -1583,7 +1595,7 @@ public class TestKMS { boolean caughtTimeout = false; try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.getKeys(); } catch (SocketTimeoutException e) { caughtTimeout = true; @@ -1593,7 +1605,7 @@ public class TestKMS { caughtTimeout = false; try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp) .generateEncryptedKey("a"); } catch (SocketTimeoutException e) { @@ -1604,7 +1616,7 @@ public class TestKMS { caughtTimeout = false; try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProviderCryptoExtension.createKeyProviderCryptoExtension(kp) .decryptEncryptedKey( new KMSClientProvider.KMSEncryptedKeyVersion("a", @@ -1651,7 +1663,7 @@ public class TestKMS { UserGroupInformation.getCurrentUser(); try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey(keyA, new KeyProvider.Options(conf)); } catch (IOException ex) { System.out.println(ex.getMessage()); @@ -1660,7 +1672,7 @@ public class TestKMS { doAs("client", new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); KeyProviderDelegationTokenExtension kpdte = KeyProviderDelegationTokenExtension. createKeyProviderDelegationTokenExtension(kp); @@ -1672,7 +1684,7 @@ public class TestKMS { nonKerberosUgi.addCredentials(credentials); try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey(keyA, new KeyProvider.Options(conf)); } catch (IOException ex) { System.out.println(ex.getMessage()); @@ -1681,7 +1693,7 @@ public class TestKMS { nonKerberosUgi.doAs(new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey(keyD, new KeyProvider.Options(conf)); return null; } @@ -1767,7 +1779,7 @@ public class TestKMS { new PrivilegedExceptionAction() { @Override public KeyProvider run() throws Exception { - KMSClientProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey("k1", new byte[16], new KeyProvider.Options(conf)); kp.createKey("k2", new byte[16], @@ -1844,7 +1856,7 @@ public class TestKMS { clientUgi.doAs(new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - final KeyProvider kp = new KMSClientProvider(uri, conf); + final KeyProvider kp = createProvider(uri, conf); kp.createKey("kaa", new KeyProvider.Options(conf)); // authorized proxyuser @@ -1956,7 +1968,7 @@ public class TestKMS { fooUgi.doAs(new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); Assert.assertNotNull(kp.createKey("kaa", new KeyProvider.Options(conf))); return null; @@ -1970,7 +1982,7 @@ public class TestKMS { @Override public Void run() throws Exception { try { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); kp.createKey("kbb", new KeyProvider.Options(conf)); Assert.fail(); } catch (Exception ex) { @@ -1986,7 +1998,7 @@ public class TestKMS { barUgi.doAs(new PrivilegedExceptionAction() { @Override public Void run() throws Exception { - KeyProvider kp = new KMSClientProvider(uri, conf); + KeyProvider kp = createProvider(uri, conf); Assert.assertNotNull(kp.createKey("kcc", new KeyProvider.Options(conf))); return null;