diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderCryptoExtension.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderCryptoExtension.java
index 7e952118e22..1ecd9f6ffe3 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderCryptoExtension.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/crypto/key/KeyProviderCryptoExtension.java
@@ -389,24 +389,39 @@ public KeyVersion decryptEncryptedKey(EncryptedKeyVersion encryptedKey)
}
/**
- * Creates a KeyProviderCryptoExtension
using a given
+ * Creates a KeyProviderCryptoExtension
using a given
* {@link KeyProvider}.
*
KeyProvider
implements the
+ * If the given KeyProvider
implements the
* {@link CryptoExtension} interface the KeyProvider
itself
- * will provide the extension functionality, otherwise a default extension
+ * will provide the extension functionality.
+ * If the given KeyProvider
implements the
+ * {@link KeyProviderExtension} interface and the KeyProvider being
+ * extended by the KeyProvider
implements the
+ * {@link CryptoExtension} interface, the KeyProvider being extended will
+ * provide the extension functionality. Otherwise, a default extension
* implementation will be used.
- *
- * @param keyProvider KeyProvider
to use to create the
+ *
+ * @param keyProvider KeyProvider
to use to create the
* KeyProviderCryptoExtension
extension.
* @return a KeyProviderCryptoExtension
instance using the
* given KeyProvider
.
*/
public static KeyProviderCryptoExtension createKeyProviderCryptoExtension(
KeyProvider keyProvider) {
- CryptoExtension cryptoExtension = (keyProvider instanceof CryptoExtension)
- ? (CryptoExtension) keyProvider
- : new DefaultCryptoExtension(keyProvider);
+ CryptoExtension cryptoExtension = null;
+ if (keyProvider instanceof CryptoExtension) {
+ cryptoExtension = (CryptoExtension) keyProvider;
+ } else if (keyProvider instanceof KeyProviderExtension &&
+ ((KeyProviderExtension)keyProvider).getKeyProvider() instanceof
+ KeyProviderCryptoExtension.CryptoExtension) {
+ KeyProviderExtension keyProviderExtension =
+ (KeyProviderExtension)keyProvider;
+ cryptoExtension =
+ (CryptoExtension)keyProviderExtension.getKeyProvider();
+ } else {
+ cryptoExtension = new DefaultCryptoExtension(keyProvider);
+ }
return new KeyProviderCryptoExtension(keyProvider, cryptoExtension);
}
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestKeyProviderCryptoExtension.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestKeyProviderCryptoExtension.java
index 62e3310173d..4712698906b 100644
--- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestKeyProviderCryptoExtension.java
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/crypto/key/TestKeyProviderCryptoExtension.java
@@ -17,9 +17,13 @@
*/
package org.apache.hadoop.crypto.key;
+import java.io.IOException;
import java.net.URI;
+import java.net.URISyntaxException;
+import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;
+import java.util.List;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
@@ -27,6 +31,7 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
+import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -132,4 +137,207 @@ public void testEncryptDecrypt() throws Exception {
assertArrayEquals("Wrong key material from decryptEncryptedKey",
manualMaterial, apiMaterial);
}
+
+ @Test
+ public void testNonDefaultCryptoExtensionSelectionWithCachingKeyProvider()
+ throws Exception {
+ Configuration config = new Configuration();
+ KeyProvider localKp = new DummyCryptoExtensionKeyProvider(config);
+ localKp = new CachingKeyProvider(localKp, 30000, 30000);
+ EncryptedKeyVersion localEkv = getEncryptedKeyVersion(config, localKp);
+ Assert.assertEquals("dummyFakeKey@1",
+ localEkv.getEncryptionKeyVersionName());
+ }
+
+ @Test
+ public void testDefaultCryptoExtensionSelectionWithCachingKeyProvider()
+ throws Exception {
+ Configuration config = new Configuration();
+ KeyProvider localKp =
+ new UserProvider.Factory().
+ createProvider(new URI("user:///"), config);
+ localKp = new CachingKeyProvider(localKp, 30000, 30000);
+ EncryptedKeyVersion localEkv = getEncryptedKeyVersion(config, localKp);
+ Assert.assertEquals(ENCRYPTION_KEY_NAME+"@0",
+ localEkv.getEncryptionKeyVersionName());
+ }
+
+ @Test
+ public void testNonDefaultCryptoExtensionSelectionOnKeyProviderExtension()
+ throws Exception {
+ Configuration config = new Configuration();
+ KeyProvider localKp = new UserProvider.Factory().
+ createProvider(new URI("user:///"), config);
+ localKp = new DummyCachingCryptoExtensionKeyProvider(localKp, 30000, 30000);
+ EncryptedKeyVersion localEkv = getEncryptedKeyVersion(config, localKp);
+ Assert.assertEquals("dummyCachingFakeKey@1",
+ localEkv.getEncryptionKeyVersionName());
+ }
+
+ private EncryptedKeyVersion getEncryptedKeyVersion(Configuration config,
+ KeyProvider localKp)
+ throws IOException, GeneralSecurityException {
+ KeyProvider.Options localOptions = new KeyProvider.Options(config);
+ localOptions.setCipher(CIPHER);
+ localOptions.setBitLength(128);
+ KeyVersion localEncryptionKey =
+ localKp.createKey(ENCRYPTION_KEY_NAME,
+ SecureRandom.getSeed(16), localOptions);
+ KeyProviderCryptoExtension localKpExt =
+ KeyProviderCryptoExtension.
+ createKeyProviderCryptoExtension(localKp);
+ return localKpExt.generateEncryptedKey(localEncryptionKey.getName());
+ }
+
+ /**
+ * Dummy class to test that this key provider is chosen to
+ * provide CryptoExtension services over the DefaultCryptoExtension.
+ */
+ public class DummyCryptoExtensionKeyProvider extends KeyProvider
+ implements KeyProviderCryptoExtension.CryptoExtension {
+
+ private KeyProvider kp;
+ private KeyVersion kv;
+ private EncryptedKeyVersion ekv;
+
+ public DummyCryptoExtensionKeyProvider(Configuration conf) {
+ super(conf);
+ conf = new Configuration();
+ try {
+ this.kp = new UserProvider.Factory().createProvider(
+ new URI("user:///"), conf);
+ this.kv = new KeyVersion(ENCRYPTION_KEY_NAME,
+ "dummyFakeKey@1", new byte[16]);
+ this.ekv = new EncryptedKeyVersion(ENCRYPTION_KEY_NAME,
+ "dummyFakeKey@1", new byte[16], kv);
+ } catch (URISyntaxException e) {
+ fail(e.getMessage());
+ } catch (IOException e) {
+ fail(e.getMessage());
+ }
+ }
+
+ @Override
+ public void warmUpEncryptedKeys(String... keyNames) throws IOException {
+
+ }
+
+ @Override
+ public void drain(String keyName) {
+
+ }
+
+ @Override
+ public EncryptedKeyVersion generateEncryptedKey(String encryptionKeyName)
+ throws IOException, GeneralSecurityException {
+ return this.ekv;
+ }
+
+ @Override
+ public KeyVersion decryptEncryptedKey(
+ EncryptedKeyVersion encryptedKeyVersion)
+ throws IOException, GeneralSecurityException {
+ return kv;
+ }
+
+ @Override
+ public KeyVersion getKeyVersion(String versionName)
+ throws IOException {
+ return this.kp.getKeyVersion(versionName);
+ }
+
+ @Override
+ public List