NIFI-11531 Migrated tests in nifi-security-utils from Groovy to Java

- Adjusted Collection and StringUtils usage to work with Java 8

Signed-off-by: David Handermann <exceptionfactory@apache.org>

(cherry picked from commit c4f7251b23d13842618acee185e33cf6afa61317)
This commit is contained in:
Emilio Setiadarma 2023-05-08 20:33:00 -05:00 committed by exceptionfactory
parent d17ffd9cd9
commit e27d2bbb2e
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
31 changed files with 5098 additions and 6340 deletions

View File

@ -245,7 +245,7 @@ public class ScryptCipherProvider extends RandomIVPBECipherProvider {
return matcher.find(); return matcher.find();
} }
private void parseSalt(String scryptSalt, byte[] rawSalt, List<Integer> params) { void parseSalt(String scryptSalt, byte[] rawSalt, List<Integer> params) {
if (StringUtils.isEmpty(scryptSalt)) { if (StringUtils.isEmpty(scryptSalt)) {
throw new IllegalArgumentException("Cannot parse empty salt"); throw new IllegalArgumentException("Cannot parse empty salt");
} }

View File

@ -141,7 +141,7 @@ public class Scrypt {
return sb.toString(); return sb.toString();
} }
private static String encodeParams(int n, int r, int p) { public static String encodeParams(int n, int r, int p) {
return Long.toString(log2(n) << 16L | r << 8 | p, 16); return Long.toString(log2(n) << 16L | r << 8 | p, 16);
} }
@ -305,7 +305,7 @@ public class Scrypt {
* @return the derived key * @return the derived key
* @throws GeneralSecurityException when HMAC_SHA256 is not available * @throws GeneralSecurityException when HMAC_SHA256 is not available
*/ */
protected static byte[] deriveScryptKey(byte[] password, byte[] salt, int n, int r, int p, int dkLen) throws GeneralSecurityException { public static byte[] deriveScryptKey(byte[] password, byte[] salt, int n, int r, int p, int dkLen) throws GeneralSecurityException {
if (n < 2 || (n & (n - 1)) != 0) { if (n < 2 || (n & (n - 1)) != 0) {
throw new IllegalArgumentException("N must be a power of 2 greater than 1"); throw new IllegalArgumentException("N must be a power of 2 greater than 1");
} }

View File

@ -1,654 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x500.style.IETFUtils
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.Extensions
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.operator.OperatorCreationException
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder
import org.bouncycastle.util.IPAddress
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.net.ssl.SSLException
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSocket
import java.security.InvalidKeyException
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.NoSuchAlgorithmException
import java.security.NoSuchProviderException
import java.security.SignatureException
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.Callable
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executors
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class CertificateUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtilsTest.class)
private static final int KEY_SIZE = 2048
private static final int DAYS_IN_YEAR = 365
private static final long YESTERDAY = System.currentTimeMillis() - 24 * 60 * 60 * 1000
private static final long ONE_YEAR_FROM_NOW = System.currentTimeMillis() + 365 * 24 * 60 * 60 * 1000
private static final String SIGNATURE_ALGORITHM = "SHA256withRSA"
private static final String PROVIDER = "BC"
private static final String SUBJECT_DN = "CN=NiFi Test Server,OU=Security,O=Apache,ST=CA,C=US"
private static final String SUBJECT_DN_LEGACY_EMAIL_ATTR_RFC2985 = "CN=NiFi Test Server/emailAddress=test@apache.org,OU=Security,O=Apache,ST=CA,C=US"
private static final String ISSUER_DN = "CN=NiFi Test CA,OU=Security,O=Apache,ST=CA,C=US"
private static final List<String> SUBJECT_ALT_NAMES = ["127.0.0.1", "nifi.nifi.apache.org"]
@BeforeAll
static void setUpOnce() {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
/**
* Generates a public/private RSA keypair using the default key size.
*
* @return the keypair
* @throws java.security.NoSuchAlgorithmException if the RSA algorithm is not available
*/
private static KeyPair generateKeyPair() throws NoSuchAlgorithmException {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA")
keyPairGenerator.initialize(KEY_SIZE)
return keyPairGenerator.generateKeyPair()
}
/**
* Generates a signed certificate using an on-demand keypair.
*
* @param dn the DN
* @return the certificate
* @throws IOException* @throws NoSuchAlgorithmException
* @throws java.security.cert.CertificateException*
* @throws java.security.NoSuchProviderException
* @throws java.security.SignatureException
* @throws OperatorCreationException
*/
private
static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException,
NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateSelfSignedX509Certificate(keyPair, dn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
}
/**
* Generates a certificate signed by the issuer key.
*
* @param dn the subject DN
* @param issuerDn the issuer DN
* @param issuerKey the issuer private key
* @return the certificate
* @throws IOException
* @throws NoSuchAlgorithmException
* @throws CertificateException
* @throws NoSuchProviderException*
* @throws SignatureException* @throws InvalidKeyException
* @throws OperatorCreationException
*/
private
static X509Certificate generateIssuedCertificate(String dn, X509Certificate issuer, KeyPair issuerKey) throws IOException,
NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKey, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
}
private static X509Certificate[] generateCertificateChain(String dn = SUBJECT_DN, String issuerDn = ISSUER_DN) {
final KeyPair issuerKeyPair = generateKeyPair()
final X509Certificate issuerCertificate = CertificateUtils.generateSelfSignedX509Certificate(issuerKeyPair, issuerDn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
final X509Certificate certificate = generateIssuedCertificate(dn, issuerCertificate, issuerKeyPair)
[certificate, issuerCertificate] as X509Certificate[]
}
@SuppressWarnings("deprecation")
private static javax.security.cert.X509Certificate generateLegacyCertificate(X509Certificate x509Certificate) {
return javax.security.cert.X509Certificate.getInstance(x509Certificate.getEncoded())
}
private static Certificate generateAbstractCertificate(X509Certificate x509Certificate) {
return x509Certificate as Certificate
}
private static Date inFuture(int days) {
return new Date(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(days))
}
@Test
void testShouldConvertAbstractX509Certificate() {
// Arrange
final X509Certificate EXPECTED_NEW_CERTIFICATE = generateCertificate(SUBJECT_DN)
logger.info("Expected certificate: ${EXPECTED_NEW_CERTIFICATE.class.canonicalName} ${EXPECTED_NEW_CERTIFICATE.subjectDN.toString()} (${EXPECTED_NEW_CERTIFICATE.getSerialNumber()})")
// Form the abstract certificate
final Certificate ABSTRACT_CERTIFICATE = generateAbstractCertificate(EXPECTED_NEW_CERTIFICATE)
logger.info("Abstract certificate: ${ABSTRACT_CERTIFICATE.class.canonicalName} (?)")
// Act
X509Certificate convertedCertificate = CertificateUtils.convertAbstractX509Certificate(ABSTRACT_CERTIFICATE)
logger.info("Converted certificate: ${convertedCertificate.class.canonicalName} ${convertedCertificate.subjectDN.toString()} (${convertedCertificate.getSerialNumber()})")
// Assert
assertEquals(EXPECTED_NEW_CERTIFICATE, convertedCertificate)
}
@Test
void testShouldDetermineClientAuthStatusFromSocket() {
// Arrange
SSLSocket needSocket = [getNeedClientAuth: { -> true }] as SSLSocket
SSLSocket wantSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> true }] as SSLSocket
SSLSocket noneSocket = [getNeedClientAuth: { -> false }, getWantClientAuth: { -> false }] as SSLSocket
// Act
ClientAuth needClientAuthStatus = CertificateUtils.getClientAuthStatus(needSocket)
logger.info("Client auth (needSocket): ${needClientAuthStatus}")
ClientAuth wantClientAuthStatus = CertificateUtils.getClientAuthStatus(wantSocket)
logger.info("Client auth (wantSocket): ${wantClientAuthStatus}")
ClientAuth noneClientAuthStatus = CertificateUtils.getClientAuthStatus(noneSocket)
logger.info("Client auth (noneSocket): ${noneClientAuthStatus}")
// Assert
assertEquals(ClientAuth.REQUIRED, needClientAuthStatus)
assertEquals(ClientAuth.WANT, wantClientAuthStatus)
assertEquals(ClientAuth.NONE, noneClientAuthStatus)
}
@Test
void testShouldExtractClientCertificatesFromSSLServerSocketWithAnyClientAuth() {
final String EXPECTED_DN = "CN=ncm.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in client mode, so the peer ("target") is a server
// Create mock sockets for each possible value of ClientAuth
SSLSocket mockNoneSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
SSLSocket mockNeedSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
SSLSocket mockWantSocket = [
getUseClientMode : { -> true },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
def resolvedServerDNs = [mockNeedSocket, mockWantSocket, mockNoneSocket].collect { SSLSocket mockSocket ->
logger.info("Running test with socket ClientAuth setting: ${CertificateUtils.getClientAuthStatus(mockSocket)}")
String serverDN = CertificateUtils.extractPeerDNFromSSLSocket(mockNoneSocket)
logger.info("Extracted server DN: ${serverDN}")
serverDN
}
// Assert
resolvedServerDNs.stream().forEach(serverDN -> assertTrue(CertificateUtils.compareDNs(serverDN, EXPECTED_DN)))
}
@Test
void testShouldNotExtractClientCertificatesFromSSLClientSocketWithClientAuthNone() {
// Arrange
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> false }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertNull(clientDN)
}
@Test
void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
// Arrange
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
}
@Test
void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthWant() {
// Arrange
SSLSession mockSession = [getPeerCertificates: { ->
throw new SSLPeerUnverifiedException("peer not authenticated")
}] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> false },
getWantClientAuth: { -> true },
getSession : { -> mockSession }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertTrue(CertificateUtils.compareDNs(clientDN, null))
}
@Test
void testShouldExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
// Arrange
final String EXPECTED_DN = "CN=client.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
Certificate[] certificateChain = generateCertificateChain(EXPECTED_DN)
logger.info("Expected DN: ${EXPECTED_DN}")
logger.info("Expected certificate chain: ${certificateChain.collect { (it as X509Certificate).getSubjectDN().name }.join(" issued by ")}")
SSLSession mockSession = [getPeerCertificates: { -> certificateChain }] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
// Act
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
// Assert
assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
}
@Test
void testShouldHandleFailureToExtractClientCertificatesFromSSLClientSocketWithClientAuthNeed() {
// Arrange
SSLSession mockSession = [getPeerCertificates: { ->
throw new SSLPeerUnverifiedException("peer not authenticated")
}] as SSLSession
// This socket is in server mode, so the peer ("target") is a client
SSLSocket mockSocket = [
getUseClientMode : { -> false },
getNeedClientAuth: { -> true },
getWantClientAuth: { -> false },
getSession : { -> mockSession }
] as SSLSocket
// Act
CertificateException ce = assertThrows(CertificateException.class,
() -> CertificateUtils.extractPeerDNFromSSLSocket(mockSocket))
// Assert
assertTrue(ce.getMessage().contains("peer not authenticated"))
}
@Test
void testShouldCompareDNs() {
// Arrange
final String DN_1_ORDERED = "CN=test1.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US"
logger.info("DN 1 Ordered : ${DN_1_ORDERED}")
final String DN_1_REVERSED = DN_1_ORDERED.split(", ").reverse().join(", ")
logger.info("DN 1 Reversed: ${DN_1_REVERSED}")
final String DN_2_ORDERED = "CN=test2.nifi.apache.org, OU=Apache NiFi, O=Apache, ST=California, C=US"
logger.info("DN 2 Ordered : ${DN_2_ORDERED}")
final String DN_2_REVERSED = DN_2_ORDERED.split(", ").reverse().join(", ")
logger.info("DN 2 Reversed: ${DN_2_REVERSED}")
// Act
// True
boolean dn1MatchesSelf = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_ORDERED)
logger.matches("DN 1, DN 1: ${dn1MatchesSelf}")
boolean dn1MatchesReversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_1_REVERSED)
logger.matches("DN 1, DN 1 (R): ${dn1MatchesReversed}")
boolean emptyMatchesEmpty = CertificateUtils.compareDNs("", "")
logger.matches("empty, empty: ${emptyMatchesEmpty}")
boolean nullMatchesNull = CertificateUtils.compareDNs(null, null)
logger.matches("null, null: ${nullMatchesNull}")
// False
boolean dn1MatchesDn2 = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_ORDERED)
logger.matches("DN 1, DN 2: ${dn1MatchesDn2}")
boolean dn1MatchesDn2Reversed = CertificateUtils.compareDNs(DN_1_ORDERED, DN_2_REVERSED)
logger.matches("DN 1, DN 2 (R): ${dn1MatchesDn2Reversed}")
boolean dn1MatchesEmpty = CertificateUtils.compareDNs(DN_1_ORDERED, "")
logger.matches("DN 1, empty: ${dn1MatchesEmpty}")
// Assert
assertTrue(dn1MatchesReversed)
assertTrue(emptyMatchesEmpty)
assertTrue(nullMatchesNull)
assertFalse(dn1MatchesDn2)
assertFalse(dn1MatchesDn2Reversed)
assertFalse(dn1MatchesEmpty)
}
@Test
void testGetCommonName(){
String dn1 = "CN=testDN,O=testOrg"
String dn2 = "O=testDN,O=testOrg"
assertEquals("testDN", CertificateUtils.getCommonName(dn1))
assertNull(CertificateUtils.getCommonName(dn2))
}
@Test
void testShouldGenerateSelfSignedCert() throws Exception {
String dn = "CN=testDN,O=testOrg"
int days = 365
X509Certificate x509Certificate = CertificateUtils.generateSelfSignedX509Certificate(generateKeyPair(), dn, SIGNATURE_ALGORITHM, days)
Date notAfter = x509Certificate.getNotAfter()
assertTrue(notAfter.after(inFuture(days - 1)))
assertTrue(notAfter.before(inFuture(days + 1)))
Date notBefore = x509Certificate.getNotBefore()
assertTrue(notBefore.after(inFuture(-1)))
assertTrue(notBefore.before(inFuture(1)))
assertEquals(dn, x509Certificate.getIssuerX500Principal().getName())
assertEquals(SIGNATURE_ALGORITHM.toUpperCase(), x509Certificate.getSigAlgName().toUpperCase())
assertEquals("RSA", x509Certificate.getPublicKey().getAlgorithm())
assertEquals(1, x509Certificate.getSubjectAlternativeNames().size())
GeneralName gn = x509Certificate.getSubjectAlternativeNames().iterator().next()
assertEquals(GeneralName.dNSName, gn.getTagNo())
assertEquals("testDN", gn.getName().toString())
x509Certificate.checkValidity()
}
@Test
void testIssueCert() throws Exception {
int days = 365
KeyPair issuerKeyPair = generateKeyPair()
X509Certificate issuer = CertificateUtils.generateSelfSignedX509Certificate(issuerKeyPair, "CN=testCa,O=testOrg", SIGNATURE_ALGORITHM, days)
String dn = "CN=testIssued, O=testOrg"
KeyPair keyPair = generateKeyPair()
X509Certificate x509Certificate = CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKeyPair, SIGNATURE_ALGORITHM, days)
assertEquals(dn, x509Certificate.getSubjectX500Principal().toString())
assertEquals(issuer.getSubjectX500Principal().toString(), x509Certificate.getIssuerX500Principal().toString())
assertEquals(keyPair.getPublic(), x509Certificate.getPublicKey())
Date notAfter = x509Certificate.getNotAfter()
assertTrue(notAfter.after(inFuture(days - 1)))
assertTrue(notAfter.before(inFuture(days + 1)))
Date notBefore = x509Certificate.getNotBefore()
assertTrue(notBefore.after(inFuture(-1)))
assertTrue(notBefore.before(inFuture(1)))
assertEquals(SIGNATURE_ALGORITHM.toUpperCase(), x509Certificate.getSigAlgName().toUpperCase())
assertEquals("RSA", x509Certificate.getPublicKey().getAlgorithm())
x509Certificate.verify(issuerKeyPair.getPublic())
}
@Test
void reorderShouldPutElementsInCorrectOrder() {
String cn = "CN=testcn"
String l = "L=testl"
String st = "ST=testst"
String o = "O=testo"
String ou = "OU=testou"
String c = "C=testc"
String street = "STREET=teststreet"
String dc = "DC=testdc"
String uid = "UID=testuid"
String surname = "SURNAME=testsurname"
String initials = "INITIALS=testinitials"
String givenName = "GIVENNAME=testgivenname"
assertEquals("$cn,$l,$st,$o,$ou,$c,$street,$dc,$uid,$surname,$givenName,$initials".toString(),
CertificateUtils.reorderDn("$surname,$st,$o,$initials,$givenName,$uid,$street,$c,$cn,$ou,$l,$dc"))
}
@Test
void testUniqueSerialNumbers() {
def running = new AtomicBoolean(true)
def executorService = Executors.newCachedThreadPool()
def serialNumbers = Collections.newSetFromMap(new ConcurrentHashMap())
try {
def futures = new ArrayList<Future>()
for (int i = 0; i < 8; i++) {
futures.add(executorService.submit(new Callable<Integer>() {
@Override
Integer call() throws Exception {
int count = 0
while (running.get()) {
def before = System.currentTimeMillis()
def serialNumber = CertificateUtils.getUniqueSerialNumber()
def after = System.currentTimeMillis()
def serialNumberMillis = serialNumber.shiftRight(32)
assertTrue(serialNumberMillis >= before)
assertTrue(serialNumberMillis <= after)
assertTrue(serialNumbers.add(serialNumber))
count++
}
return count
}
}))
}
Thread.sleep(1000)
running.set(false)
def totalRuns = 0
for (int i = 0; i < futures.size(); i++) {
try {
def numTimes = futures.get(i).get()
logger.info("future $i executed $numTimes times")
totalRuns += numTimes
} catch (ExecutionException e) {
throw e.getCause()
}
}
logger.info("Generated ${serialNumbers.size()} unique serial numbers")
assertEquals(totalRuns, serialNumbers.size())
} finally {
executorService.shutdown()
}
}
@Test
void testShouldGenerateIssuedCertificateWithSans() {
// Arrange
final String SUBJECT_DN = "CN=localhost"
final List<String> SANS = ["127.0.0.1", "nifi.nifi.apache.org"]
logger.info("Creating a certificate with subject: ${SUBJECT_DN} and SAN: ${SANS}")
final KeyPair subjectKeyPair = generateKeyPair()
final KeyPair issuerKeyPair = generateKeyPair()
final X509Certificate issuerCertificate = CertificateUtils.generateSelfSignedX509Certificate(issuerKeyPair, ISSUER_DN, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
// Form the SANS into GeneralName instances and populate the container with the array
def gns = SANS.collect { String san ->
new GeneralName(GeneralName.dNSName, san)
}
def generalNames = new GeneralNames(gns as GeneralName[])
logger.info("Created GeneralNames object: ${generalNames.names*.toString()}")
// Form the Extensions object
ExtensionsGenerator extensionsGenerator = new ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.subjectAlternativeName, false, generalNames)
Extensions extensions = extensionsGenerator.generate()
logger.info("Generated extensions object: ${extensions.oids()*.toString()}")
// Act
X509Certificate certificate = CertificateUtils.generateIssuedCertificate(SUBJECT_DN, subjectKeyPair.public, extensions, issuerCertificate, issuerKeyPair, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
logger.info("Issued certificate with subject: ${certificate.getSubjectDN().name} and SAN: ${certificate.getSubjectAlternativeNames().join(",")}")
// Assert
assertEquals(SUBJECT_DN, certificate.getSubjectDN().name)
assertEquals(SANS.size(), certificate.getSubjectAlternativeNames().size())
assertTrue(certificate.getSubjectAlternativeNames()*.last().containsAll(SANS))
}
@Test
void testShouldDetectTlsErrors() {
// Arrange
final String msg = "Test exception"
// SSLPeerUnverifiedException isn't specifically defined in the method, but is a subclass of SSLException so it should be caught
List<Throwable> directErrors = [new TlsException(msg), new SSLPeerUnverifiedException(msg), new CertificateException(msg), new SSLException(msg)]
List<Throwable> causedErrors = directErrors.collect { Throwable cause -> new Exception(msg, cause) } + [
new Exception(msg,
new Exception("Nested $msg",
new Exception("Double nested $msg",
new TlsException("Triple nested $msg"))))]
List<Throwable> unrelatedErrors = [new Exception(msg), new IllegalArgumentException(msg), new NullPointerException(msg)]
// Act
def directResults = directErrors.collect { Throwable e -> CertificateUtils.isTlsError(e) }
def causedResults = causedErrors.collect { Throwable e -> CertificateUtils.isTlsError(e) }
def unrelatedResults = unrelatedErrors.collect { Throwable e -> CertificateUtils.isTlsError(e) }
logger.info("Direct results: ${directResults}")
logger.info("Caused results: ${causedResults}")
logger.info("Unrelated results: ${unrelatedResults}")
// Assert
assertTrue(directResults.every())
assertTrue(causedResults.every())
assertFalse(unrelatedResults.any())
}
@Test
void testGetExtensionsFromCSR() {
// Arrange
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA")
KeyPair keyPair = generator.generateKeyPair()
Extensions sanExtensions = createDomainAlternativeNamesExtensions(SUBJECT_ALT_NAMES, SUBJECT_DN)
JcaPKCS10CertificationRequestBuilder jcaPKCS10CertificationRequestBuilder = new JcaPKCS10CertificationRequestBuilder(new X500Name(SUBJECT_DN), keyPair.getPublic())
jcaPKCS10CertificationRequestBuilder.addAttribute(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest, sanExtensions)
JcaContentSignerBuilder jcaContentSignerBuilder = new JcaContentSignerBuilder("SHA256WITHRSA")
JcaPKCS10CertificationRequest jcaPKCS10CertificationRequest = new JcaPKCS10CertificationRequest(jcaPKCS10CertificationRequestBuilder.build(jcaContentSignerBuilder.build(keyPair.getPrivate())))
// Act
Extensions extensions = CertificateUtils.getExtensionsFromCSR(jcaPKCS10CertificationRequest)
// Assert
assert(extensions.equivalent(sanExtensions))
}
@Test
void testExtractUserNameFromDN() {
String expected = "NiFi Test Server"
assertEquals(CertificateUtils.extractUsername(SUBJECT_DN), expected)
assertEquals(CertificateUtils.extractUsername(SUBJECT_DN_LEGACY_EMAIL_ATTR_RFC2985), expected)
}
// Using this directly from tls-toolkit results in a dependency loop, so it's added here for testing purposes.
private static Extensions createDomainAlternativeNamesExtensions(List<String> domainAlternativeNames, String requestedDn) throws IOException {
List<GeneralName> namesList = new ArrayList<>()
try {
final String cn = IETFUtils.valueToString(new X500Name(requestedDn).getRDNs(BCStyle.CN)[0].getFirst().getValue())
namesList.add(new GeneralName(GeneralName.dNSName, cn))
} catch (Exception e) {
throw new IOException("Failed to extract CN from request DN: " + requestedDn, e)
}
if (domainAlternativeNames != null) {
for (String alternativeName : domainAlternativeNames) {
namesList.add(new GeneralName(IPAddress.isValid(alternativeName) ? GeneralName.iPAddress : GeneralName.dNSName, alternativeName))
}
}
GeneralNames subjectAltNames = new GeneralNames(namesList.toArray([] as GeneralName[]))
ExtensionsGenerator extGen = new ExtensionsGenerator()
extGen.addExtension(Extension.subjectAlternativeName, false, subjectAltNames)
return extGen.generate()
}
}

View File

@ -1,509 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.commons.codec.binary.Base64
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import java.nio.charset.StandardCharsets
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class Argon2CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(Argon2CipherProviderGroovyTest.class)
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess"
private static List<EncryptionMethod> strongKDFEncryptionMethods
private static final int DEFAULT_KEY_LENGTH = 128
private final String SALT_HEX = "0123456789ABCDEFFEDCBA9876543210"
private static ArrayList<Integer> AES_KEY_LENGTHS
RandomIVPBECipherProvider cipherProvider
private final IntRange FULL_SALT_LENGTH_RANGE= (49..53)
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
strongKDFEncryptionMethods = EncryptionMethod.values().findAll { it.isCompatibleWithStrongKDFs() }
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
AES_KEY_LENGTHS = [128, 192, 256]
}
@BeforeEach
void setUp() throws Exception {
// Very fast parameters to test for correctness rather than production values
cipherProvider = new Argon2CipherProvider(1024, 1, 3)
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testArgon2ShouldSupportExternalCompatibility() throws Exception {
// Arrange
// Default values are hashLength = 32, memory = 1024, parallelism = 1, iterations = 3, but the provided salt will contain the parameters used
cipherProvider = new Argon2CipherProvider()
final String PLAINTEXT = "This is a plaintext message."
final String PASSWORD = "thisIsABadPassword"
final int hashLength = 256
// These values can be generated by running `$ ./openssl_argon2.rb` in the terminal
final byte[] SALT = Hex.decodeHex("68d29a1d8021f45954333767358a2492" as char[])
logger.info("Expected salt: ${Hex.encodeHexString(SALT)}")
final byte[] IV = Hex.decodeHex("808590f35f9fba14dbda9c2bb2b76a79" as char[])
final String CIPHER_TEXT = "d672412857916880c79d573aa4f9d4971b85f07438d6f62f38a0e31314caa2e5"
logger.sanity("Ruby cipher text: ${CIPHER_TEXT}")
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT as char[])
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Sanity check
String rubyKeyHex = "8caf581795886d38f0c605e3d674f4961c658ee3625a8e8868be36c902d234ef"
logger.sanity("Using key: ${rubyKeyHex}")
logger.sanity("Using IV: ${Hex.encodeHexString(IV)}")
Cipher rubyCipher = Cipher.getInstance(encryptionMethod.algorithm, "BC")
def rubyKey = new SecretKeySpec(Hex.decodeHex(rubyKeyHex as char[]), "AES")
def ivSpec = new IvParameterSpec(IV)
rubyCipher.init(Cipher.ENCRYPT_MODE, rubyKey, ivSpec)
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
logger.sanity("Decrypted generated cipher text successfully")
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text successfully")
// $argon2id$v=19$m=memory,t=iterations,p=parallelism$saltB64$hashB64
final String FULL_HASH = "\$argon2id\$v=19\$m=256,t=3,p=1\$aNKaHYAh9FlUMzdnNYokkg\$jK9YF5WIbTjwxgXj1nT0lhxljuNiWo6IaL42yQLSNO8"
logger.info("Full Hash: ${FULL_HASH}")
final String FULL_SALT = FULL_HASH[0..<FULL_HASH.lastIndexOf("\$")]
logger.info("Full salt: ${FULL_SALT}")
final String[] hashComponents = FULL_HASH.split("\\\$")
logger.info("hashComponents: ${Arrays.toString(hashComponents)}")
Map<String, String> saltParams = hashComponents[3].split(",").collectEntries { String pair ->
pair.split("=")
}
logger.info("saltParams: ${saltParams}")
def saltB64 = hashComponents[4]
byte[] salt = Base64.decodeBase64(saltB64)
logger.info("Salt: ${Hex.encodeHexString(salt)}")
assertArrayEquals(SALT, salt)
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}")
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, FULL_SALT.bytes, IV, hashLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
@Test
void testGetCipherShouldRejectInvalidIV() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final def INVALID_IVS = (0..15).collect { int length -> new byte[length] }
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
INVALID_IVS.each { byte[] badIV ->
logger.info("IV: ${Hex.encodeHexString(badIV)} ${badIV.length}")
// Encrypt should print a warning about the bad IV but overwrite it
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true)
// Decrypt should fail
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false))
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("01" * 16 as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final int LONG_KEY_LENGTH = 256
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, LONG_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, LONG_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldNotAcceptInvalidSalts() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
final def INVALID_SALTS = ['argon2', '$3a$11$', 'x', '$2a$10$']
final LENGTH_MESSAGE = "The raw salt must be greater than or equal to 8 bytes"
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
logger.expected(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains(LENGTH_MESSAGE))
}
}
@Test
void testGetCipherShouldHandleUnformattedSalts() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
final def RECOVERABLE_SALTS = ['$ab$00$acbdefghijklmnopqrstuv', '$4$1$1$0123456789abcdef', '$400$1$1$abcdefghijklmnopqrstuv']
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
RECOVERABLE_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
// Assert
assertNotNull(cipher)
}
}
@Test
void testGetCipherShouldRejectEmptySalt() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
logger.expected(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()"))
}
@Test
void testGenerateSaltShouldProvideValidSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new Argon2CipherProvider()
// Act
byte[] saltBytes = cipherProvider.generateSalt()
logger.info("Generated salt ${Hex.encodeHexString(saltBytes)}")
String fullSalt = new String(saltBytes, StandardCharsets.UTF_8)
logger.info("Generated salt (${saltBytes.length}): ${fullSalt}".toString())
def rawSaltB64 = (fullSalt =~ /\$([\w\+\/]+)\$?$/)[0][1]
logger.info("Extracted B64 raw salt (${rawSaltB64.size()}): ${rawSaltB64}".toString())
byte[] rawSaltBytes = Base64.decodeBase64(rawSaltB64)
// Assert
boolean isValidFormattedSalt = cipherProvider.isArgon2FormattedSalt(fullSalt)
logger.info("Salt is Argon2 format: ${isValidFormattedSalt}")
assertTrue(isValidFormattedSalt)
boolean fullSaltIsValidLength = FULL_SALT_LENGTH_RANGE.contains(saltBytes.length)
logger.info("Salt length (${fullSalt.length()}) in valid range (${FULL_SALT_LENGTH_RANGE})")
assertTrue(fullSaltIsValidLength)
byte [] notExpected = new byte[16]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, rawSaltBytes))
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("00" * 16 as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
logger.expected(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("01" * 16 as char[])
final def VALID_KEY_LENGTHS = AES_KEY_LENGTHS
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
VALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("00" * 16 as char[])
final def INVALID_KEY_LENGTHS = [-1, 40, 64, 112, 512]
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
INVALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
logger.expected(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@Test
void testArgon2ShouldNotAcceptInvalidPassword() {
// Arrange
String badPassword = ""
byte[] salt = [0x01 as byte] * 16
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true))
// Assert
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
}
@Test
void testShouldParseSalt() throws Exception {
// Arrange
cipherProvider = cipherProvider as Argon2CipherProvider
final byte[] EXPECTED_RAW_SALT = Hex.decodeHex("8622b26906d9c900660a60f5cc673233" as char[])
final int EXPECTED_MEMORY = 1024
final int EXPECTED_PARALLELISM = 4
final int EXPECTED_ITERATIONS = 1
final String FORMATTED_SALT = "\$argon2id\$v=19\$m=1024,t=4,p=1\$hiKyaQbZyQBmCmD1zGcyMw"
logger.info("Using salt: ${FORMATTED_SALT}")
byte[] rawSalt = new byte[16]
def params = []
// Act
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params)
// Assert
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
assertEquals(EXPECTED_MEMORY, params[0])
assertEquals(EXPECTED_PARALLELISM, params[1])
assertEquals(EXPECTED_ITERATIONS, params[2])
}
@Test
void testShouldRejectInvalidSalt() throws Exception {
// Arrange
cipherProvider = cipherProvider as Argon2CipherProvider
final String FULL_HASH = "\$argon2id\$v=19\$m=1024,t=4,p=1\$hiKyaQbZyQBmCmD1zGcyMw\$rc+ec+/hQeBcwzjH+OEmUtaTUqhZYKN4ZKJtWzFZYjQ"
logger.info("Using salt: ${FULL_HASH}")
byte[] rawSalt = new byte[16]
List<Integer> params = []
// Act
boolean isValid = cipherProvider.isArgon2FormattedSalt(FULL_HASH)
logger.info("Argon2 formatted salt: ${isValid}")
// Assert
assertFalse(isValid)
}
@Test
void testShouldExtractSalt() throws Exception {
// Arrange
cipherProvider = cipherProvider as Argon2CipherProvider
final byte[] EXPECTED_RAW_SALT = Hex.decodeHex("8622b26906d9c900660a60f5cc673233" as char[])
final String FORMATTED_SALT = "\$argon2id\$v=19\$m=1024,t=4,p=1\$hiKyaQbZyQBmCmD1zGcyMw"
logger.info("Using salt: ${FORMATTED_SALT}")
byte[] rawSalt
// Act
rawSalt = cipherProvider.extractRawSaltFromArgon2Salt(FORMATTED_SALT)
logger.info("rawSalt: ${Hex.encodeHexString(rawSalt)}")
// Assert
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
}
}

View File

@ -1,458 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.util.encoders.Hex
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class Argon2SecureHasherTest {
private static final Logger logger = LoggerFactory.getLogger(Argon2SecureHasherTest.class)
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testShouldBeDeterministicWithStaticSalt() {
// Arrange
int hashLength = 32
int memory = 8
int parallelism = 4
int iterations = 4
logger.info("Generating Argon2 hash for hash length: ${hashLength} B, mem: ${memory} KiB, parallelism: ${parallelism}, iterations: ${iterations}")
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "a73a471f51b2900901a00b81e770b9c1dfc595602bb7aec64cd27754a4174919"
Argon2SecureHasher a2sh = new Argon2SecureHasher(hashLength, memory, parallelism, iterations)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = a2sh.hashRaw(inputBytes)
String hashHex = new String(Hex.encode(hash))
logger.info("Generated hash: ${hashHex}")
results << hashHex
}
// Assert
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
void testShouldBeDifferentWithRandomSalt() {
// Arrange
int hashLength = 32
int memory = 8
int parallelism = 4
int iterations = 4
logger.info("Generating Argon2 hash for hash length: ${hashLength} B, mem: ${memory} KiB, parallelism: ${parallelism}, iterations: ${iterations}")
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "a73a471f51b2900901a00b81e770b9c1dfc595602bb7aec64cd27754a4174919"
Argon2SecureHasher a2sh = new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 16)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = a2sh.hashRaw(inputBytes)
String hashHex = Hex.encode(hash)
logger.info("Generated hash: ${hashHex}")
results << hashHex
}
// Assert
assertTrue(results.unique().size() == results.size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
void testShouldHandleArbitrarySalt() {
// Arrange
int hashLength = 32
int memory = 8
int parallelism = 4
int iterations = 4
logger.info("Generating Argon2 hash for hash length: ${hashLength} B, mem: ${memory} KiB, parallelism: ${parallelism}, iterations: ${iterations}")
def input = "This is a sensitive value"
byte[] inputBytes = input.bytes
final String EXPECTED_HASH_HEX = "a73a471f51b2900901a00b81e770b9c1dfc595602bb7aec64cd27754a4174919"
logger.info("Expected Hash Hex length: ${EXPECTED_HASH_HEX.length()}")
final String EXPECTED_HASH_BASE64 = "pzpHH1GykAkBoAuB53C5wd/FlWArt67GTNJ3VKQXSRk"
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX)
// Static salt instance
Argon2SecureHasher staticSaltHasher = new Argon2SecureHasher(hashLength, memory, parallelism, iterations)
Argon2SecureHasher arbitrarySaltHasher = new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 16)
final byte[] STATIC_SALT = AbstractSecureHasher.STATIC_SALT
final String DIFFERENT_STATIC_SALT = "Diff Static Salt"
// Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes)
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT)
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8))
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes)
String staticSaltHashHex = staticSaltHasher.hashHex(input)
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8))
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT)
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input)
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input)
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8))
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT)
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
void testShouldValidateArbitrarySalt() {
// Arrange
int hashLength = 32
int memory = 8
int parallelism = 4
int iterations = 4
logger.info("Generating Argon2 hash for hash length: ${hashLength} B, mem: ${memory} KiB, parallelism: ${parallelism}, iterations: ${iterations}")
def input = "This is a sensitive value"
byte[] inputBytes = input.bytes
// Static salt instance
Argon2SecureHasher secureHasher = new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 16)
final byte[] STATIC_SALT = "bad_sal".bytes
// Act
assertThrows(IllegalArgumentException.class, { ->
new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 7) })
assertThrows(RuntimeException.class, { -> secureHasher.hashRaw(inputBytes, STATIC_SALT) })
assertThrows(RuntimeException.class, { -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) })
assertThrows(RuntimeException.class, { -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) })
}
@Test
void testShouldFormatHex() {
// Arrange
String input = "This is a sensitive value"
final String EXPECTED_HASH_HEX = "0c2920c52f28e0a2c77d006ec6138c8dc59580881468b85541cf886abdebcf18"
Argon2SecureHasher a2sh = new Argon2SecureHasher(32, 4096, 1, 3)
// Act
String hashHex = a2sh.hashHex(input)
logger.info("Generated hash: ${hashHex}")
// Assert
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
void testShouldFormatBase64() {
// Arrange
String input = "This is a sensitive value"
final String EXPECTED_HASH_B64 = "DCkgxS8o4KLHfQBuxhOMjcWVgIgUaLhVQc+Iar3rzxg"
Argon2SecureHasher a2sh = new Argon2SecureHasher(32, 4096, 1, 3)
// Act
String hashB64 = a2sh.hashBase64(input)
logger.info("Generated hash: ${hashB64}")
// Assert
assertEquals(EXPECTED_HASH_B64, hashB64)
}
@Test
void testShouldHandleNullInput() {
// Arrange
List<String> inputs = [null, ""]
final String EXPECTED_HASH_HEX = "8e5625a66b94ed9d31c1496d7f9ff49249cf05d6753b50ba0e2bf2a1108973dd"
final String EXPECTED_HASH_B64 = "jlYlpmuU7Z0xwUltf5/0kknPBdZ1O1C6DivyoRCJc90"
Argon2SecureHasher a2sh = new Argon2SecureHasher(32, 4096, 1, 3)
def hexResults = []
def b64Results = []
// Act
inputs.each { String input ->
String hashHex = a2sh.hashHex(input)
logger.info("Generated hash: ${hashHex}")
hexResults << hashHex
String hashB64 = a2sh.hashBase64(input)
logger.info("Generated hash: ${hashB64}")
b64Results << hashB64
}
// Assert
hexResults.forEach(hexResult -> assertEquals(EXPECTED_HASH_HEX, hexResult))
b64Results.forEach(b64Result -> assertEquals(EXPECTED_HASH_B64, b64Result))
}
/**
* This test can have the minimum time threshold updated to determine if the performance
* is still sufficient compared to the existing threat model.
*/
@EnabledIfSystemProperty(named = "nifi.test.performance", matches = "true")
@Test
void testDefaultCostParamsShouldBeSufficient() {
// Arrange
int testIterations = 100 //_000
byte[] inputBytes = "This is a sensitive value".bytes
Argon2SecureHasher a2sh = new Argon2SecureHasher(16, 2**16, 8, 5)
def results = []
def resultDurations = []
// Act
testIterations.times { int i ->
long startNanos = System.nanoTime()
byte[] hash = a2sh.hashRaw(inputBytes)
long endNanos = System.nanoTime()
long durationNanos = endNanos - startNanos
String hashHex = Hex.encode(hash)
logger.info("Generated hash: ${hashHex} in ${durationNanos} ns")
results << hashHex
resultDurations << durationNanos
}
def milliDurations = [resultDurations.min(), resultDurations.max(), resultDurations.sum()/resultDurations.size()].collect { it / 1_000_000 }
logger.info("Min/Max/Avg durations in ms: ${milliDurations}")
// Assert
final long MIN_DURATION_NANOS = 500_000_000 // 500 ms
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
void testShouldVerifyHashLengthBoundary() throws Exception {
// Arrange
final int hashLength = 128
// Act
boolean valid = Argon2SecureHasher.isHashLengthValid(hashLength)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailHashLengthBoundary() throws Exception {
// Arrange
def hashLengths = [-8, 0, 1, 2]
// Act
def results = hashLengths.collect { hashLength ->
def isValid = Argon2SecureHasher.isHashLengthValid(hashLength)
[hashLength, isValid]
}
// Assert
results.each { hashLength, isHashLengthValid ->
logger.info("For hashLength value ${hashLength}, hashLength is ${isHashLengthValid ? "valid" : "invalid"}")
assertFalse(isHashLengthValid)
}
}
@Test
void testShouldVerifyMemorySizeBoundary() throws Exception {
// Arrange
final int memory = 2048
// Act
boolean valid = Argon2SecureHasher.isMemorySizeValid(memory)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailMemorySizeBoundary() throws Exception {
// Arrange
def memorySizes = [-12, 0, 1, 6]
// Act
def results = memorySizes.collect { memory ->
def isValid = Argon2SecureHasher.isMemorySizeValid(memory)
[memory, isValid]
}
// Assert
results.each { memory, isMemorySizeValid ->
logger.info("For memory size ${memory}, memory is ${isMemorySizeValid ? "valid" : "invalid"}")
assertFalse(isMemorySizeValid)
}
}
@Test
void testShouldVerifyParallelismBoundary() throws Exception {
// Arrange
final int parallelism = 4
// Act
boolean valid = Argon2SecureHasher.isParallelismValid(parallelism)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailParallelismBoundary() throws Exception {
// Arrange
def parallelisms = [-8, 0, 16777220, 16778000]
// Act
def results = parallelisms.collect { parallelism ->
def isValid = Argon2SecureHasher.isParallelismValid(parallelism)
[parallelism, isValid]
}
// Assert
results.each { parallelism, isParallelismValid ->
logger.info("For parallelization factor ${parallelism}, parallelism is ${isParallelismValid ? "valid" : "invalid"}")
assertFalse(isParallelismValid)
}
}
@Test
void testShouldVerifyIterationsBoundary() throws Exception {
// Arrange
final int iterations = 4
// Act
boolean valid = Argon2SecureHasher.isIterationsValid(iterations)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailIterationsBoundary() throws Exception {
// Arrange
def iterationCounts = [-50, -1, 0]
// Act
def results = iterationCounts.collect { iterations ->
def isValid = Argon2SecureHasher.isIterationsValid(iterations)
[iterations, isValid]
}
// Assert
results.each { iterations, isIterationsValid ->
logger.info("For iteration counts ${iterations}, iteration is ${isIterationsValid ? "valid" : "invalid"}")
assertFalse(isIterationsValid)
}
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [0, 64]
// Act and Assert
Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher()
saltLengths.forEach(saltLength -> {
assertTrue(argon2SecureHasher.isSaltLengthValid(saltLength))
})
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [-16, 4]
// Act and Assert
Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher()
saltLengths.forEach(saltLength -> assertFalse(argon2SecureHasher.isSaltLengthValid(saltLength)))
}
@Test
void testShouldCreateHashOfDesiredLength() throws Exception {
// Arrange
def hashLengths = [16, 32]
final String PASSWORD = "password"
final byte[] SALT = [0x00] * 16
final byte[] EXPECTED_HASH = Hex.decode("411c9c87e7c91d8c8eacc418665bd2e1")
// Act
Map<Integer, byte[]> results = hashLengths.collectEntries { hashLength ->
Argon2SecureHasher ash = new Argon2SecureHasher(hashLength, 8, 1, 3)
def hash = ash.hashRaw(PASSWORD.bytes, SALT)
logger.info("Hashed password ${PASSWORD} with salt ${Hex.encode(SALT)} to ${Hex.encode(hash)}".toString())
[hashLength, hash]
}
// Assert
assertFalse(Arrays.equals(Arrays.copyOf(results[16], 16), Arrays.copyOf(results[32], 16)))
// Demonstrates that internal hash truncation is not supported
// assert results.every { int k, byte[] v -> v[0..15] as byte[] == EXPECTED_HASH}
}
}

View File

@ -1,656 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import at.favre.lib.crypto.bcrypt.BCrypt
import at.favre.lib.crypto.bcrypt.Radix64Encoder
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import java.nio.charset.StandardCharsets
import java.security.MessageDigest
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assertions.assertThrows
class BcryptCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(BcryptCipherProviderGroovyTest.class)
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess"
private static List<EncryptionMethod> strongKDFEncryptionMethods
private static final int DEFAULT_KEY_LENGTH = 128
public static final String MICROBENCHMARK = "microbenchmark"
private static ArrayList<Integer> AES_KEY_LENGTHS
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
strongKDFEncryptionMethods = EncryptionMethod.values().findAll { it.isCompatibleWithStrongKDFs() }
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
AES_KEY_LENGTHS = [128, 192, 256]
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("01" * 16 as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final int LONG_KEY_LENGTH = 256
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, LONG_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, LONG_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testHashPWShouldMatchTestVectors() {
// Arrange
final byte[] PASSWORD = 'abcdefghijklmnopqrstuvwxyz'.getBytes(StandardCharsets.UTF_8)
final byte[] SALT = new Radix64Encoder.Default().decode('fVH8e28OQRj9tqiDXs1e1u'.getBytes(StandardCharsets.UTF_8))
final String EXPECTED_HASH = '$2a$10$fVH8e28OQRj9tqiDXs1e1uxpsjN0c7II7YPKXua2NAKYvM6iQk7dq'
final int WORK_FACTOR = 10
// Act
String libraryCalculatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, SALT, PASSWORD), StandardCharsets.UTF_8)
logger.info("Generated ${libraryCalculatedHash}")
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher(WORK_FACTOR)
String secureHasherCalculatedHash = new String(bcryptSecureHasher.hashRaw(PASSWORD, SALT), StandardCharsets.UTF_8)
logger.info("Generated ${secureHasherCalculatedHash}")
// Assert
assertEquals(EXPECTED_HASH, secureHasherCalculatedHash)
assertEquals(EXPECTED_HASH, secureHasherCalculatedHash)
}
@Test
void testGetCipherShouldSupportExternalCompatibility() throws Exception {
// Arrange
final int WORK_FACTOR = 10
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(WORK_FACTOR)
final String PLAINTEXT = "This is a plaintext message."
final String PASSWORD = "thisIsABadPassword"
// These values can be generated by running `$ ./openssl_bcrypt` in the terminal
// The Ruby bcrypt gem does not expose the custom Radix64 decoder, so maintain the R64 encoding from the output and decode here
final byte[] SALT = new Radix64Encoder.Default().decode("LBVzJoPgh.85YCvnos4BKO".bytes)
final byte[] IV = Hex.decodeHex("bae8a9d935748a75ff0e0bbd95a4f024" as char[])
// $v2$w2$base64_salt_22__base64_hash_31
final String FULL_HASH = "\$2a\$10\$LBVzJoPgh.85YCvnos4BKOyYM.LRni6UbU4v/CEPBkmFIiigADJZi"
logger.info("Full Hash: ${FULL_HASH}")
final String HASH = FULL_HASH[-31..-1]
logger.info(" Hash: ${HASH.padLeft(60, " ")}")
logger.info(" B64 Salt: ${customB64Encode(SALT).padLeft(29, " ")}")
String extractedSalt = FULL_HASH[7..<29]
logger.info("Extracted Salt: ${extractedSalt}")
String extractedSaltHex = Hex.encodeHexString(customB64Decode(extractedSalt))
logger.info("Extracted Salt (hex): ${extractedSaltHex}")
logger.info(" Expected Salt (hex): ${Hex.encodeHexString(SALT)}")
final String CIPHER_TEXT = "d232b68e7aa38242d195c54b8f360d8b8d6b7580b190ffdeef99f5fe460bd6b0"
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT as char[])
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
// Sanity check
Cipher rubyCipher = Cipher.getInstance(encryptionMethod.algorithm, "BC")
def rubyKey = new SecretKeySpec(Hex.decodeHex("01ea96ccc48a1d045bd7f461721b94a8" as char[]), "AES")
def ivSpec = new IvParameterSpec(IV)
rubyCipher.init(Cipher.ENCRYPT_MODE, rubyKey, ivSpec)
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.info("Expected cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text and generated cipher text successfully")
// Sanity for hash generation
final String FULL_SALT = FULL_HASH[0..<29]
logger.sanity("Salt from external: ${FULL_SALT}")
String generatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, BcryptCipherProvider.extractRawSalt(FULL_SALT), PASSWORD.bytes))
logger.sanity("Generated hash: ${generatedHash}")
assertEquals(FULL_HASH, generatedHash)
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, FULL_SALT.bytes, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
private static byte[] customB64Decode(String input) {
customB64Decode(input.bytes)
}
private static byte[] customB64Decode(byte[] input) {
new Radix64Encoder.Default().decode(input)
}
private static String customB64Encode(String input) {
customB64Encode(input.bytes)
}
private static String customB64Encode(byte[] input) {
new String(new Radix64Encoder.Default().encode(input), StandardCharsets.UTF_8)
}
@Test
void testGetCipherShouldHandleFullSalt() throws Exception {
// Arrange
final int WORK_FACTOR = 10
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(WORK_FACTOR)
final String PLAINTEXT = "This is a plaintext message."
final String PASSWORD = "thisIsABadPassword"
// These values can be generated by running `$ ./openssl_bcrypt.rb` in the terminal
final byte[] IV = Hex.decodeHex("bae8a9d935748a75ff0e0bbd95a4f024" as char[])
// $v2$w2$base64_salt_22__base64_hash_31
final String FULL_HASH = "\$2a\$10\$LBVzJoPgh.85YCvnos4BKOyYM.LRni6UbU4v/CEPBkmFIiigADJZi"
logger.info("Full Hash: ${FULL_HASH}")
final String FULL_SALT = FULL_HASH[0..<29]
logger.info(" Salt: ${FULL_SALT}")
final String HASH = FULL_HASH[-31..-1]
logger.info(" Hash: ${HASH.padLeft(60, " ")}")
String extractedSalt = FULL_HASH[7..<29]
logger.info("Extracted Salt: ${extractedSalt}")
String extractedSaltHex = Hex.encodeHexString(customB64Decode(extractedSalt))
logger.info("Extracted Salt (hex): ${extractedSaltHex}")
final String CIPHER_TEXT = "d232b68e7aa38242d195c54b8f360d8b8d6b7580b190ffdeef99f5fe460bd6b0"
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT as char[])
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, FULL_SALT.bytes, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
@Test
void testGetCipherShouldHandleUnformedSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "thisIsABadPassword"
final def INVALID_SALTS = ['$ab$00$acbdefghijklmnopqrstuv', 'bad_salt', '$3a$11$', 'x', '$2a$10$']
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("The salt must be of the format \$2a\$10\$gUVbkVzp79H8YaCOsCVZNu. To generate a salt, use BcryptCipherProvider#generateSalt"))
}
}
String bytesToBitString(byte[] bytes) {
bytes.collect {
String.format("%8s", Integer.toBinaryString(it & 0xFF)).replace(' ', '0')
}.join("")
}
String spaceString(String input, int blockSize = 4) {
input.collect { it.padLeft(blockSize, " ") }.join("")
}
@Test
void testGetCipherShouldRejectEmptySalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "thisIsABadPassword"
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Two different errors -- one explaining the no-salt method is not supported, and the other for an empty byte[] passed
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assertTrue((iae.getMessage() =~ "The salt must be of the format .* To generate a salt, use BcryptCipherProvider#generateSalt").find())
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("00" * 16 as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("01" * 16 as char[])
// Currently only AES ciphers are compatible with Bcrypt, so redundant to test all algorithms
final def VALID_KEY_LENGTHS = AES_KEY_LENGTHS
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
VALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("00" * 16 as char[])
// Currently only AES ciphers are compatible with Bcrypt, so redundant to test all algorithms
final def INVALID_KEY_LENGTHS = [-1, 40, 64, 112, 512]
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
INVALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@Test
void testGenerateSaltShouldUseProvidedWorkFactor() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(11)
int workFactor = cipherProvider.getWorkFactor()
// Act
final byte[] saltBytes = cipherProvider.generateSalt()
String salt = new String(saltBytes)
logger.info("Salt: ${salt}")
// Assert
assertTrue((salt =~ /^\$2[axy]\$\d{2}\$/).find())
assertTrue(salt.contains("\$" + workFactor + "\$"))
}
/**
* For {@code 1.12.0} the key derivation process was changed. Previously, the entire hash output
* ({@code $2a$10$9XUQnxGEUsRdLqEhxY3xNujOQQkW3spKqxssi.Ox39VhhxB.z4496}) was fed to {@code SHA-512}
* to stretch the hash output to a custom key length (128, 192, or 256 bits) because the Bcrypt hash
* output length is fixed at 184 bits. The new key derivation process only feeds the <em>non-salt
* hash output</em> (({@code jOQQkW3spKqxssi.Ox39VhhxB.z4496})) into the digest.
* @throws Exception
*/
@Test
void testGetCipherShouldUseHashOutputOnlyToDeriveKey() throws Exception {
// Arrange
BcryptCipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
String saltString = new String(SALT, StandardCharsets.UTF_8)
logger.info("Using fixed Bcrypt salt: ${saltString}")
// Determine the expected key bytes using the new key derivation process
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher(cipherProvider.getWorkFactor(), cipherProvider.getDefaultSaltLength())
byte[] rawSaltBytes = BcryptCipherProvider.extractRawSalt(saltString)
byte[] hashOutputBytes = bcryptSecureHasher.hashRaw(PASSWORD.getBytes(StandardCharsets.UTF_8), rawSaltBytes)
logger.info("Raw hash output (${hashOutputBytes.length}): ${Hex.encodeHexString(hashOutputBytes)}")
MessageDigest sha512 = MessageDigest.getInstance("SHA-512", "BC")
byte[] keyDigestBytes = sha512.digest(hashOutputBytes[-31..-1] as byte[])
logger.info("Key digest (${keyDigestBytes.length}): ${Hex.encodeHexString(keyDigestBytes)}")
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Expected key verification
int keyLength = CipherUtility.parseKeyLengthFromAlgorithm(em.getAlgorithm())
byte[] derivedKeyBytes = Arrays.copyOf(keyDigestBytes, keyLength / 8 as int)
logger.info("Derived key (${derivedKeyBytes.length}): ${Hex.encodeHexString(derivedKeyBytes)}")
Cipher verificationCipher = Cipher.getInstance(em.getAlgorithm())
verificationCipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(derivedKeyBytes, em.algorithm), new IvParameterSpec(iv))
byte[] verificationBytes = verificationCipher.doFinal(cipherBytes)
String verificationRecovered = new String(verificationBytes, StandardCharsets.UTF_8)
logger.info("Verified: ${verificationRecovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
assertEquals(PLAINTEXT, verificationRecovered)
}
}
@Test
void testGetCipherShouldBeBackwardCompatibleWithFullHashKeyDerivation() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption using the legacy key derivation process
Cipher cipher = cipherProvider.getInitializedCipher(em, PASSWORD, SALT, new byte[0], DEFAULT_KEY_LENGTH, true, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getLegacyDecryptCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldHandleNullSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4)
final String PASSWORD = "shortPassword"
final byte[] SALT = null
final EncryptionMethod em = EncryptionMethod.AES_CBC
// Act
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
IllegalArgumentException encryptIae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true))
logger.warn("Encrypt error: " + encryptIae.getMessage())
byte[] cipherBytes = PLAINTEXT.reverse().getBytes(StandardCharsets.UTF_8)
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
IllegalArgumentException decryptIae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, [0x00] * 16 as byte[], DEFAULT_KEY_LENGTH, false))
logger.warn("Decrypt error: " + decryptIae.getMessage())
// Assert
assertTrue(encryptIae.getMessage().contains("The salt must be of the format"))
assertTrue(decryptIae.getMessage().contains("The salt must be of the format"))
}
@Disabled("This test can be run on a specific machine to evaluate if the default work factor is sufficient")
@Test
void testDefaultConstructorShouldProvideStrongWorkFactor() {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider()
// Values taken from http://wildlyinaccurate.com/bcrypt-choosing-a-work-factor/ and http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
// Calculate the work factor to reach 500 ms
int minimumWorkFactor = calculateMinimumWorkFactor()
logger.info("Determined minimum safe work factor to be ${minimumWorkFactor}")
// Act
int workFactor = cipherProvider.getWorkFactor()
logger.info("Default work factor ${workFactor}")
// Assert
assertTrue("The default work factor for BcryptCipherProvider is too weak. Please update the default value to a stronger level.", workFactor >= minimumWorkFactor)
}
/**
* Returns the work factor required for a derivation to exceed 500 ms on this machine. Code adapted from http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
*
* @return the minimum bcrypt work factor
*/
private static int calculateMinimumWorkFactor() {
// High start-up cost, so run multiple times for better benchmarking
final int RUNS = 10
// Benchmark using a work factor of 5 (the second-lowest allowed)
int workFactor = 5
String salt = new BcryptCipherProvider(5).generateSalt()
// Run once to prime the system
double duration = time {
BCrypt.hashpw(MICROBENCHMARK, salt)
}
logger.info("First run of work factor ${workFactor} took ${duration} ms (ignored)")
def durations = []
RUNS.times { int i ->
duration = time {
BCrypt.hashpw(MICROBENCHMARK, salt)
}
logger.info("Work factor ${workFactor} took ${duration} ms")
durations << duration
}
duration = durations.sum() / durations.size()
logger.info("Work factor ${workFactor} averaged ${duration} ms")
// Increasing the work factor by 1 would double the run time
// Keep increasing N until the estimated duration is over 500 ms
while (duration < 500) {
workFactor += 1
duration *= 2
}
logger.info("Returning work factor ${workFactor} for ${duration} ms")
return workFactor
}
private static double time(Closure c) {
long start = System.nanoTime()
c.call()
long end = System.nanoTime()
return (end - start) / 1_000_000.0
}
}

View File

@ -1,356 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import at.favre.lib.crypto.bcrypt.Radix64Encoder
import org.bouncycastle.util.encoders.Hex
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class BcryptSecureHasherTest {
private static final Logger logger = LoggerFactory.getLogger(BcryptSecureHasher)
@BeforeAll
static void setupOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testShouldBeDeterministicWithStaticSalt() {
// Arrange
int cost = 4
logger.info("Generating Bcrypt hash for cost factor: ${cost}")
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "24326124303424526b6a4559512f526245447959554b6553304471622e596b4c5331655a2e6c61586550484c69464d783937564c566d47354250454f"
BcryptSecureHasher bcryptSH = new BcryptSecureHasher(cost)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = bcryptSH.hashRaw(inputBytes)
String hashHex = new String(Hex.encode(hash))
logger.info("Generated hash: ${hashHex}")
results << hashHex
}
// Assert
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
void testShouldBeDifferentWithRandomSalt() {
// Arrange
int cost = 4
int saltLength = 16
logger.info("Generating Bcrypt hash for cost factor: ${cost}, salt length: ${saltLength}")
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "24326124303424546d6c47615342546447463061574d6755324673642e38675a347a6149356d6b4d50594c542e344e68337962455a4678384b676a75"
BcryptSecureHasher bcryptSH = new BcryptSecureHasher(cost, saltLength)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = bcryptSH.hashRaw(inputBytes)
String hashHex = Hex.encode(hash)
logger.info("Generated hash: ${hashHex}")
results << hashHex
}
// Assert
assertEquals(results.size(), results.unique().size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
void testShouldHandleArbitrarySalt() {
// Arrange
int cost = 4
logger.info("Generating Bcrypt hash for cost factor: ${cost}")
def input = "This is a sensitive value"
byte[] inputBytes = input.bytes
final String EXPECTED_HASH_HEX = "24326124303424526b6a4559512f526245447959554b6553304471622e596b4c5331655a2e6c61586550484c69464d783937564c566d47354250454f"
final String EXPECTED_HASH_BASE64 = "JDJhJDA0JFJrakVZUS9SYkVEeVlVS2VTMERxYi5Za0xTMWVaLmxhWGVQSExpRk14OTdWTFZtRzVCUEVP"
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX)
// Static salt instance
BcryptSecureHasher staticSaltHasher = new BcryptSecureHasher(cost)
BcryptSecureHasher arbitrarySaltHasher = new BcryptSecureHasher(cost, 16)
final byte[] STATIC_SALT = AbstractSecureHasher.STATIC_SALT
final String DIFFERENT_STATIC_SALT = "Diff Static Salt"
// Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes)
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT)
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8))
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes)
String staticSaltHashHex = staticSaltHasher.hashHex(input)
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8))
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT)
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input)
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input)
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8))
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT)
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
void testShouldValidateArbitrarySalt() {
// Arrange
int cost = 4
logger.info("Generating Bcrypt hash for cost factor: ${cost}")
def input = "This is a sensitive value"
byte[] inputBytes = input.bytes
// Static salt instance
BcryptSecureHasher secureHasher = new BcryptSecureHasher(cost, 16)
final byte[] STATIC_SALT = "bad_sal".bytes
assertThrows(IllegalArgumentException.class, { -> new BcryptSecureHasher(cost, 7) })
assertThrows(RuntimeException.class, { -> secureHasher.hashRaw(inputBytes, STATIC_SALT) })
assertThrows(RuntimeException.class, { -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) })
assertThrows(RuntimeException.class, { -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) })
}
@Test
void testShouldFormatHex() {
// Arrange
String input = "This is a sensitive value"
final String EXPECTED_HASH_HEX = "24326124313224526b6a4559512f526245447959554b6553304471622e5852696135344d4e356c5a44515243575874516c4c696d476669635a776871"
BcryptSecureHasher bcryptSH = new BcryptSecureHasher()
// Act
String hashHex = bcryptSH.hashHex(input)
logger.info("Generated hash: ${hashHex}")
// Assert
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
void testShouldFormatBase64() {
// Arrange
String input = "This is a sensitive value"
final String EXPECTED_HASH_BASE64 = "JDJhJDEyJFJrakVZUS9SYkVEeVlVS2VTMERxYi5YUmlhNTRNTjVsWkRRUkNXWHRRbExpbUdmaWNad2hx"
BcryptSecureHasher bcryptSH = new BcryptSecureHasher()
// Act
String hashB64 = bcryptSH.hashBase64(input)
logger.info("Generated hash: ${hashB64}")
// Assert
assertEquals(EXPECTED_HASH_BASE64, hashB64)
}
@Test
void testShouldHandleNullInput() {
// Arrange
List<String> inputs = [null, ""]
final String EXPECTED_HASH_HEX = ""
final String EXPECTED_HASH_BASE64 = ""
BcryptSecureHasher bcryptSH = new BcryptSecureHasher()
def hexResults = []
def B64Results = []
// Act
inputs.each { String input ->
String hashHex = bcryptSH.hashHex(input)
logger.info("Generated hex-encoded hash: ${hashHex}")
hexResults << hashHex
String hashB64 = bcryptSH.hashBase64(input)
logger.info("Generated B64-encoded hash: ${hashB64}")
B64Results << hashB64
}
// Assert
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
}
/**
* This test can have the minimum time threshold updated to determine if the performance
* is still sufficient compared to the existing threat model.
*/
@EnabledIfSystemProperty(named = "nifi.test.performance", matches = "true")
@Test
void testDefaultCostParamsShouldBeSufficient() {
// Arrange
int testIterations = 100
byte[] inputBytes = "This is a sensitive value".bytes
BcryptSecureHasher bcryptSH = new BcryptSecureHasher()
def results = []
def resultDurations = []
// Act
testIterations.times { int i ->
long startNanos = System.nanoTime()
byte[] hash = bcryptSH.hashRaw(inputBytes)
long endNanos = System.nanoTime()
long durationNanos = endNanos - startNanos
String hashHex = Hex.encode(hash)
logger.info("Generated hash: ${hashHex} in ${durationNanos} ns")
results << hashHex
resultDurations << durationNanos
}
// Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
void testShouldVerifyCostBoundary() throws Exception {
// Arrange
final int cost = 14
// Act and Assert
assertTrue(BcryptSecureHasher.isCostValid(cost))
}
@Test
void testShouldFailCostBoundary() throws Exception {
// Arrange
def costFactors = [-8, 0, 40]
// Act and Assert
costFactors.forEach(costFactor -> assertFalse(BcryptSecureHasher.isCostValid(costFactor)))
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [0, 16]
// Act and Assert
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher()
saltLengths.forEach(saltLength -> assertTrue(bcryptSecureHasher.isSaltLengthValid(saltLength)))
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [-8, 1]
// Act and Assert
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher()
saltLengths.forEach(saltLength -> assertFalse(bcryptSecureHasher.isSaltLengthValid(saltLength)))
}
@Test
void testShouldConvertRadix64ToBase64() {
// Arrange
final String INPUT_RADIX_64 = "mm7MiKjvXVYCujVUlKRKiu"
final byte[] EXPECTED_BYTES = new Radix64Encoder.Default().decode(INPUT_RADIX_64.bytes)
logger.info("Plain bytes: ${Hex.encode(EXPECTED_BYTES)}")
// Uses standard Base64 library but removes padding chars
final String EXPECTED_MIME_B64 = Base64.encoder.encodeToString(EXPECTED_BYTES).replaceAll(/=/, '')
// Act
String convertedBase64 = BcryptSecureHasher.convertBcryptRadix64ToMimeBase64(INPUT_RADIX_64)
logger.info("Converted (R64) ${INPUT_RADIX_64} to (B64) ${convertedBase64}")
String convertedRadix64 = BcryptSecureHasher.convertMimeBase64ToBcryptRadix64(convertedBase64)
logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}")
// Assert
assertEquals(EXPECTED_MIME_B64, convertedBase64)
assertEquals(INPUT_RADIX_64, convertedRadix64)
}
@Test
void testConvertRadix64ToBase64ShouldHandlePeriod() {
// Arrange
final String INPUT_RADIX_64 = "75x373yP7atxMD3pVgsdO."
final byte[] EXPECTED_BYTES = new Radix64Encoder.Default().decode(INPUT_RADIX_64.bytes)
logger.info("Plain bytes: ${Hex.encode(EXPECTED_BYTES)}")
// Uses standard Base64 library but removes padding chars
final String EXPECTED_MIME_B64 = Base64.encoder.encodeToString(EXPECTED_BYTES).replaceAll(/=/, '')
// Act
String convertedBase64 = BcryptSecureHasher.convertBcryptRadix64ToMimeBase64(INPUT_RADIX_64)
logger.info("Converted (R64) ${INPUT_RADIX_64} to (B64) ${convertedBase64}")
String convertedRadix64 = BcryptSecureHasher.convertMimeBase64ToBcryptRadix64(convertedBase64)
logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}")
// Assert
assertEquals(EXPECTED_MIME_B64, convertedBase64)
assertEquals(INPUT_RADIX_64, convertedRadix64)
}
}

View File

@ -1,301 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.apache.nifi.security.util.KeyDerivationFunction
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertTrue
class CipherUtilityGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(CipherUtilityGroovyTest.class)
// TripleDES must precede DES for automatic grouping precedence
private static final List<String> CIPHERS = ["AES", "TRIPLEDES", "DES", "RC2", "RC4", "RC5", "TWOFISH"]
private static final List<String> SYMMETRIC_ALGORITHMS = EncryptionMethod.values().findAll { it.algorithm.startsWith("PBE") || it.algorithm.startsWith("AES") }*.algorithm
private static final Map<String, List<String>> ALGORITHMS_MAPPED_BY_CIPHER = SYMMETRIC_ALGORITHMS.groupBy { String algorithm -> CIPHERS.find { algorithm.contains(it) } }
// Manually mapped as of 03/21/21 1.13.0
private static final Map<Integer, List<String>> ALGORITHMS_MAPPED_BY_KEY_LENGTH = [
(40) : ["PBEWITHSHAAND40BITRC2-CBC",
"PBEWITHSHAAND40BITRC4"],
(64) : ["PBEWITHMD5ANDDES",
"PBEWITHSHA1ANDDES"],
(112): ["PBEWITHSHAAND2-KEYTRIPLEDES-CBC",
"PBEWITHSHAAND3-KEYTRIPLEDES-CBC"],
(128): ["PBEWITHMD5AND128BITAES-CBC-OPENSSL",
"PBEWITHMD5ANDRC2",
"PBEWITHSHA1ANDRC2",
"PBEWITHSHA256AND128BITAES-CBC-BC",
"PBEWITHSHAAND128BITAES-CBC-BC",
"PBEWITHSHAAND128BITRC2-CBC",
"PBEWITHSHAAND128BITRC4",
"PBEWITHSHAANDTWOFISH-CBC",
"AES/CBC/NoPadding",
"AES/CBC/PKCS7Padding",
"AES/CTR/NoPadding",
"AES/GCM/NoPadding"],
(192): ["PBEWITHMD5AND192BITAES-CBC-OPENSSL",
"PBEWITHSHA256AND192BITAES-CBC-BC",
"PBEWITHSHAAND192BITAES-CBC-BC",
"AES/CBC/NoPadding",
"AES/CBC/PKCS7Padding",
"AES/CTR/NoPadding",
"AES/GCM/NoPadding"],
(256): ["PBEWITHMD5AND256BITAES-CBC-OPENSSL",
"PBEWITHSHA256AND256BITAES-CBC-BC",
"PBEWITHSHAAND256BITAES-CBC-BC",
"AES/CBC/NoPadding",
"AES/CBC/PKCS7Padding",
"AES/CTR/NoPadding",
"AES/GCM/NoPadding"]
]
@BeforeAll
static void setUpOnce() {
Security.addProvider(new BouncyCastleProvider())
// Fix because TRIPLEDES -> DESede
def tripleDESAlgorithms = ALGORITHMS_MAPPED_BY_CIPHER.remove("TRIPLEDES")
ALGORITHMS_MAPPED_BY_CIPHER.put("DESede", tripleDESAlgorithms)
logger.info("Mapped algorithms: ${ALGORITHMS_MAPPED_BY_CIPHER}")
}
@Test
void testShouldParseCipherFromAlgorithm() {
// Arrange
final def EXPECTED_ALGORITHMS = ALGORITHMS_MAPPED_BY_CIPHER
// Act
SYMMETRIC_ALGORITHMS.each { String algorithm ->
String cipher = CipherUtility.parseCipherFromAlgorithm(algorithm)
logger.info("Extracted ${cipher} from ${algorithm}")
// Assert
assertTrue(EXPECTED_ALGORITHMS.get(cipher).contains(algorithm))
}
}
@Test
void testShouldParseKeyLengthFromAlgorithm() {
// Arrange
final def EXPECTED_ALGORITHMS = ALGORITHMS_MAPPED_BY_KEY_LENGTH
// Act
SYMMETRIC_ALGORITHMS.each { String algorithm ->
int keyLength = CipherUtility.parseKeyLengthFromAlgorithm(algorithm)
logger.info("Extracted ${keyLength} from ${algorithm}")
// Assert
assertTrue(EXPECTED_ALGORITHMS.get(keyLength).contains(algorithm))
}
}
@Test
void testShouldDetermineValidKeyLength() {
// Arrange
// Act
ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
algorithms.each { String algorithm ->
logger.info("Checking ${keyLength} for ${algorithm}")
// Assert
assertTrue(CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm)))
}
}
}
@Test
void testShouldDetermineInvalidKeyLength() {
// Arrange
// Act
ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
algorithms.each { String algorithm ->
def invalidKeyLengths = [-1, 0, 1]
if (algorithm =~ "RC\\d") {
invalidKeyLengths += [39, 2049]
} else {
invalidKeyLengths += keyLength + 1
}
logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")
// Assert
invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))))
}
}
}
@Test
void testShouldDetermineValidKeyLengthForAlgorithm() {
// Arrange
// Act
ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
algorithms.each { String algorithm ->
logger.info("Checking ${keyLength} for ${algorithm}")
// Assert
assertTrue(CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm))
}
}
}
@Test
void testShouldDetermineInvalidKeyLengthForAlgorithm() {
// Arrange
// Act
ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
algorithms.each { String algorithm ->
def invalidKeyLengths = [-1, 0, 1]
if (algorithm =~ "RC\\d") {
invalidKeyLengths += [39, 2049]
} else {
invalidKeyLengths += keyLength + 1
}
logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")
// Assert
invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)))
}
}
// Extra hard-coded checks
String algorithm = "PBEWITHSHA256AND256BITAES-CBC-BC"
int invalidKeyLength = 192
logger.info("Checking ${invalidKeyLength} for ${algorithm}")
assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm))
}
@Test
void testShouldGetValidKeyLengthsForAlgorithm() {
// Arrange
def rcKeyLengths = (40..2048).asList()
def CIPHER_KEY_SIZES = [
AES : [128, 192, 256],
DES : [56, 64],
DESede : [56, 64, 112, 128, 168, 192],
RC2 : rcKeyLengths,
RC4 : rcKeyLengths,
RC5 : rcKeyLengths,
TWOFISH: [128, 192, 256]
]
def SINGLE_KEY_SIZE_ALGORITHMS = EncryptionMethod.values()*.algorithm.findAll { CipherUtility.parseActualKeyLengthFromAlgorithm(it) != -1 }
logger.info("Single key size algorithms: ${SINGLE_KEY_SIZE_ALGORITHMS}")
def MULTIPLE_KEY_SIZE_ALGORITHMS = EncryptionMethod.values()*.algorithm - SINGLE_KEY_SIZE_ALGORITHMS
MULTIPLE_KEY_SIZE_ALGORITHMS.removeAll { it.contains("PGP") }
logger.info("Multiple key size algorithms: ${MULTIPLE_KEY_SIZE_ALGORITHMS}")
// Act
SINGLE_KEY_SIZE_ALGORITHMS.each { String algorithm ->
def EXPECTED_KEY_SIZES = [CipherUtility.parseKeyLengthFromAlgorithm(algorithm)]
def validKeySizes = CipherUtility.getValidKeyLengthsForAlgorithm(algorithm)
logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")
// Assert
assertEquals(EXPECTED_KEY_SIZES, validKeySizes)
}
// Act
MULTIPLE_KEY_SIZE_ALGORITHMS.each { String algorithm ->
String cipher = CipherUtility.parseCipherFromAlgorithm(algorithm)
def EXPECTED_KEY_SIZES = CIPHER_KEY_SIZES[cipher]
def validKeySizes = CipherUtility.getValidKeyLengthsForAlgorithm(algorithm)
logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")
// Assert
assertEquals(EXPECTED_KEY_SIZES, validKeySizes)
}
}
@Test
void testShouldFindSequence() {
// Arrange
byte[] license = """Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
""".bytes
byte[] apache = "Apache".bytes
byte[] software = "Software".bytes
byte[] asf = "ASF".bytes
byte[] kafka = "Kafka".bytes
// Act
int apacheIndex = CipherUtility.findSequence(license, apache)
logger.info("Looking for ${Hex.encodeHexString(apache)}; found at ${apacheIndex}")
int softwareIndex = CipherUtility.findSequence(license, software)
logger.info("Looking for ${Hex.encodeHexString(software)}; found at ${softwareIndex}")
int asfIndex = CipherUtility.findSequence(license, asf)
logger.info("Looking for ${Hex.encodeHexString(asf)}; found at ${asfIndex}")
int kafkaIndex = CipherUtility.findSequence(license, kafka)
logger.info("Looking for ${Hex.encodeHexString(kafka)}; found at ${kafkaIndex}")
// Assert
assertEquals(16, apacheIndex)
assertEquals(23, softwareIndex)
assertEquals(44, asfIndex)
assertEquals(-1, kafkaIndex)
}
@Test
void testShouldExtractRawSalt() {
// Arrange
byte[] PLAIN_SALT = [0xab] * 16
String ARGON2_SALT = Argon2CipherProvider.formSalt(PLAIN_SALT, 8, 1, 1)
String BCRYPT_SALT = BcryptCipherProvider.formatSaltForBcrypt(PLAIN_SALT, 10)
String SCRYPT_SALT = ScryptCipherProvider.formatSaltForScrypt(PLAIN_SALT, 10, 1, 1)
// Act
Map<Object, byte[]> results = KeyDerivationFunction.values().findAll { !it.isStrongKDF() }.collectEntries { KeyDerivationFunction weakKdf ->
[weakKdf, CipherUtility.extractRawSalt(PLAIN_SALT, weakKdf)]
}
results.put(KeyDerivationFunction.ARGON2, CipherUtility.extractRawSalt(ARGON2_SALT.bytes, KeyDerivationFunction.ARGON2))
results.put(KeyDerivationFunction.BCRYPT, CipherUtility.extractRawSalt(BCRYPT_SALT.bytes, KeyDerivationFunction.BCRYPT))
results.put(KeyDerivationFunction.SCRYPT, CipherUtility.extractRawSalt(SCRYPT_SALT.bytes, KeyDerivationFunction.SCRYPT))
results.put(KeyDerivationFunction.PBKDF2, CipherUtility.extractRawSalt(PLAIN_SALT, KeyDerivationFunction.PBKDF2))
// Assert
results.values().forEach(v -> assertArrayEquals(PLAIN_SALT, v))
}
}

View File

@ -1,105 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
class HashAlgorithmTest {
private static final Logger logger = LoggerFactory.getLogger(HashAlgorithmTest.class)
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testDetermineBrokenAlgorithms() throws Exception {
// Arrange
def algorithms = HashAlgorithm.values()
// Act
def brokenAlgorithms = algorithms.findAll { !it.isStrongAlgorithm() }
logger.info("Broken algorithms: ${brokenAlgorithms}")
// Assert
assertEquals([HashAlgorithm.MD2, HashAlgorithm.MD5, HashAlgorithm.SHA1], brokenAlgorithms)
}
@Test
void testShouldBuildAllowableValueDescription() {
// Arrange
def algorithms = HashAlgorithm.values()
// Act
def descriptions = algorithms.collect { HashAlgorithm algorithm ->
algorithm.buildAllowableValueDescription()
}
// Assert
descriptions.forEach(description -> assertTrue((description =~ /.* \(\d+ byte output\).*/).find()) )
descriptions.stream()
.filter(description -> (description =~ "MD2|MD5|SHA-1").find() )
.forEach(description -> assertTrue(description.contains("WARNING")))
}
@Test
void testDetermineBlake2Algorithms() {
def algorithms = HashAlgorithm.values()
// Act
def blake2Algorithms = algorithms.findAll { it.isBlake2() }
logger.info("Blake2 algorithms: ${blake2Algorithms}")
// Assert
assertEquals([HashAlgorithm.BLAKE2_160, HashAlgorithm.BLAKE2_256, HashAlgorithm.BLAKE2_384, HashAlgorithm.BLAKE2_512], blake2Algorithms)
}
@Test
void testShouldMatchAlgorithmByName() {
// Arrange
def algorithms = HashAlgorithm.values()
// Act
algorithms.each { HashAlgorithm algorithm ->
def transformedNames = [algorithm.name, algorithm.name.toUpperCase(), algorithm.name.toLowerCase()]
logger.info("Trying with names: ${transformedNames}")
transformedNames.each { String name ->
HashAlgorithm found = HashAlgorithm.fromName(name)
// Assert
assertEquals(name.toUpperCase(), found.name)
}
}
}
}

View File

@ -1,429 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.nifi.components.AllowableValue
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.util.encoders.Hex
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class HashServiceTest {
private static final Logger logger = LoggerFactory.getLogger(HashServiceTest.class)
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testShouldHashValue() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256
final String KNOWN_VALUE = "apachenifi"
final String EXPECTED_HASH = "dc4bd945723b9c234f1be408e8ceb78660b481008b8ab5b71eb2aa3b4f08357a"
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH)
Closure threeArgString = { -> HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8) }
Closure twoArgString = { -> HashService.hashValue(algorithm, KNOWN_VALUE) }
Closure threeArgStringRaw = { -> HashService.hashValueRaw(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8) }
Closure twoArgStringRaw = { -> HashService.hashValueRaw(algorithm, KNOWN_VALUE) }
Closure twoArgBytesRaw = { -> HashService.hashValueRaw(algorithm, KNOWN_VALUE.bytes) }
def scenarios = [threeArgString : threeArgString,
twoArgString : twoArgString,
threeArgStringRaw: threeArgStringRaw,
twoArgStringRaw : twoArgStringRaw,
twoArgBytesRaw : twoArgBytesRaw,
]
// Act
scenarios.each { String name, Closure closure ->
def result = closure.call()
logger.info("${name.padLeft(20)}: ${result.class.simpleName.padLeft(8)} ${result}")
// Assert
if (result instanceof byte[]) {
assertArrayEquals(EXPECTED_HASH_BYTES, result)
} else {
assertEquals(EXPECTED_HASH, result)
}
}
}
@Test
void testHashValueShouldDifferOnDifferentEncodings() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256
final String KNOWN_VALUE = "apachenifi"
// Act
String utf8Hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8)
logger.info(" UTF-8: ${utf8Hash}")
String utf16Hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_16)
logger.info("UTF-16: ${utf16Hash}")
// Assert
assertNotEquals(utf8Hash, utf16Hash)
}
/**
* This test ensures that the service properly handles UTF-16 encoded data to return it without
* the Big Endian Byte Order Mark (BOM). Java treats UTF-16 encoded data without a BOM as Big Endian by default on decoding, but when <em>encoding</em>, it inserts a BE BOM in the data.
*
* Examples:
*
* "apachenifi"
*
* * UTF-8: 0x61 0x70 0x61 0x63 0x68 0x65 0x6E 0x69 0x66 0x69
* * UTF-16: 0xFE 0xFF 0x00 0x61 0x00 0x70 0x00 0x61 0x00 0x63 0x00 0x68 0x00 0x65 0x00 0x6E 0x00 0x69 0x00 0x66 0x00 0x69
* * UTF-16LE: 0x61 0x00 0x70 0x00 0x61 0x00 0x63 0x00 0x68 0x00 0x65 0x00 0x6E 0x00 0x69 0x00 0x66 0x00 0x69 0x00
* * UTF-16BE: 0x00 0x61 0x00 0x70 0x00 0x61 0x00 0x63 0x00 0x68 0x00 0x65 0x00 0x6E 0x00 0x69 0x00 0x66 0x00 0x69
*
* The result of "UTF-16" decoding should have the 0xFE 0xFF stripped on return by encoding in UTF-16BE directly, which will not insert a BOM.
*
* See also: <a href="https://unicode.org/faq/utf_bom.html#bom10">https://unicode.org/faq/utf_bom.html#bom10</a>
*/
@Test
void testHashValueShouldHandleUTF16BOMIssue() {
// Arrange
HashAlgorithm algorithm = HashAlgorithm.SHA256
final String KNOWN_VALUE = "apachenifi"
List<Charset> charsets = [StandardCharsets.UTF_8, StandardCharsets.UTF_16, StandardCharsets.UTF_16LE, StandardCharsets.UTF_16BE]
charsets.each { Charset charset ->
logger.info("[${charset.name().padLeft(9)}]: ${printHexBytes(KNOWN_VALUE, charset)}")
}
final def EXPECTED_SHA_256_HASHES = [
"utf_8" : "dc4bd945723b9c234f1be408e8ceb78660b481008b8ab5b71eb2aa3b4f08357a",
"utf_16" : "f370019c2a41a8285077beb839f7566240e2f0ca970cb67aed5836b89478df91",
"utf_16be": "f370019c2a41a8285077beb839f7566240e2f0ca970cb67aed5836b89478df91",
"utf_16le": "7e285dc64d3a8c3cb4e04304577eebbcb654f2245373874e48e597a8b8f15aff",
]
EXPECTED_SHA_256_HASHES.each { k, hash ->
logger.expected("SHA-256(${k.padLeft(9)}(${KNOWN_VALUE})) = ${hash}")
}
// Act
charsets.each { Charset charset ->
// Calculate the expected hash value given the character set
String hash = HashService.hashValue(algorithm, KNOWN_VALUE, charset)
logger.info("${algorithm.name}(${KNOWN_VALUE}, ${charset.name().padLeft(9)}) = ${hash}")
// Assert
assertEquals(EXPECTED_SHA_256_HASHES[translateStringToMapKey(charset.name())], hash)
}
}
@Test
void testHashValueShouldDefaultToUTF8() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256
final String KNOWN_VALUE = "apachenifi"
// Act
String explicitUTF8Hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8)
logger.info("Explicit UTF-8: ${explicitUTF8Hash}")
String implicitUTF8Hash = HashService.hashValue(algorithm, KNOWN_VALUE)
logger.info("Implicit UTF-8: ${implicitUTF8Hash}")
byte[] explicitUTF8HashBytes = HashService.hashValueRaw(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8)
logger.info("Explicit UTF-8 bytes: ${explicitUTF8HashBytes}")
byte[] implicitUTF8HashBytes = HashService.hashValueRaw(algorithm, KNOWN_VALUE)
logger.info("Implicit UTF-8 bytes: ${implicitUTF8HashBytes}")
byte[] implicitUTF8HashBytesDefault = HashService.hashValueRaw(algorithm, KNOWN_VALUE.bytes)
logger.info("Implicit UTF-8 bytes: ${implicitUTF8HashBytesDefault}")
// Assert
assertEquals(explicitUTF8Hash, implicitUTF8Hash)
assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytes)
assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytesDefault)
}
@Test
void testShouldRejectNullAlgorithm() {
// Arrange
final String KNOWN_VALUE = "apachenifi"
Closure threeArgString = { -> HashService.hashValue(null, KNOWN_VALUE, StandardCharsets.UTF_8) }
Closure twoArgString = { -> HashService.hashValue(null, KNOWN_VALUE) }
Closure threeArgStringRaw = { -> HashService.hashValueRaw(null, KNOWN_VALUE, StandardCharsets.UTF_8) }
Closure twoArgStringRaw = { -> HashService.hashValueRaw(null, KNOWN_VALUE) }
Closure twoArgBytesRaw = { -> HashService.hashValueRaw(null, KNOWN_VALUE.bytes) }
def scenarios = [threeArgString : threeArgString,
twoArgString : twoArgString,
threeArgStringRaw: threeArgStringRaw,
twoArgStringRaw : twoArgStringRaw,
twoArgBytesRaw : twoArgBytesRaw,
]
// Act
scenarios.entrySet().forEach(entry -> {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> entry.getValue().call())
assertTrue(iae.message.contains("The hash algorithm cannot be null"))
})
}
@Test
void testShouldRejectNullValue() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256
Closure threeArgString = { -> HashService.hashValue(algorithm, null, StandardCharsets.UTF_8) }
Closure twoArgString = { -> HashService.hashValue(algorithm, null) }
Closure threeArgStringRaw = { -> HashService.hashValueRaw(algorithm, null, StandardCharsets.UTF_8) }
Closure twoArgStringRaw = { -> HashService.hashValueRaw(algorithm, null as String) }
Closure twoArgBytesRaw = { -> HashService.hashValueRaw(algorithm, null as byte[]) }
def scenarios = [threeArgString : threeArgString,
twoArgString : twoArgString,
threeArgStringRaw: threeArgStringRaw,
twoArgStringRaw : twoArgStringRaw,
twoArgBytesRaw : twoArgBytesRaw,
]
// Act
scenarios.entrySet().forEach(entry -> {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> entry.getValue().call())
assertTrue(iae.message.contains("The value cannot be null"))
})
}
@Test
void testShouldHashConstantValue() throws Exception {
// Arrange
def algorithms = HashAlgorithm.values()
final String KNOWN_VALUE = "apachenifi"
/* These values were generated using command-line tools (openssl dgst -md5, shasum [-a 1 224 256 384 512 512224 512256], rhash --sha3-224, b2sum -l 224)
* Ex: {@code $ echo -n "apachenifi" | openssl dgst -md5}
*/
final def EXPECTED_HASHES = [
md2 : "25d261790198fa543b3436b4755ded91",
md5 : "a968b5ec1d52449963dcc517789baaaf",
sha_1 : "749806dbcab91a695ac85959aca610d84f03c6a7",
sha_224 : "4933803881a4ccb9b3453b829263d3e44852765db12958267ad46135",
sha_256 : "dc4bd945723b9c234f1be408e8ceb78660b481008b8ab5b71eb2aa3b4f08357a",
sha_384 : "a5205271df448e55afc4a553e91a8fea7d60d080d390d1f3484fcb6318abe94174cf3d36ea4eb1a4d5ed7637c99dec0c",
sha_512 : "0846ae23e122fbe090e94d45f886aa786acf426f56496e816a64e292b78c1bb7a962dbfd32c5c73bbee432db400970e22fd65498c862da72a305311332c6f302",
sha_512_224: "ecf78a026035528e3097ea7289257d1819d273f60636060fbba43bfb",
sha_512_256: "d90bdd8ad7e19f2d7848a45782d5dbe056a8213a94e03d9a35d6f44dbe7ee6cd",
sha3_224 : "2e9d1ea677847dce686ca2444cc4525f114443652fcb55af4c7286cd",
sha3_256 : "b1b3cd90a21ef60caba5ec1bf12ffcb833e52a0ae26f0ab7c4f9ccfa9c5c025b",
sha3_384 : "ca699a2447032857bf4f7e84fa316264f0c1870f9330031d5d75a0770644353c268b36d0522a3cf62e60f9401aadc37c",
sha3_512 : "cb9059d9b7ec4fde4d9710160a694e7ac2a4dd9969dee43d730066ded7b80d3eefdb4cae7622d21f6cfe16092e24f1ad6ca5924767118667654cf71b7abaaca4",
blake2_160 : "7bc5a408dba4f1934d9090c4d75c65bfa0c7c90c",
blake2_256 : "40b8935dc5ed153846fb08dac8e7999ba04a74f4dab28415c39847a15c211447",
blake2_384 : "40716eddc8cfcf666d980804fed294c43fe9436a9787367a3086b45d69791fd5cef1a16c17235ea289c1e40a899b4f6b",
blake2_512 : "5f34525b130c11c469302ef6734bf6eedb1eca5d7445a3c4ae289ab58dd13ef72531966bfe2f67c4bf49c99dd14dae92d245f241482307d29bf25c45a1085026"
]
// Act
def generatedHashes = algorithms.collectEntries { HashAlgorithm algorithm ->
String hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8)
logger.info("${algorithm.getName().padLeft(11)}('${KNOWN_VALUE}') [${hash.length() / 2}] = ${hash}")
[(algorithm.name), hash]
}
// Assert
generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName)
assertEquals(EXPECTED_HASHES[key], hash)
}
}
@Test
void testShouldHashEmptyValue() throws Exception {
// Arrange
def algorithms = HashAlgorithm.values()
final String EMPTY_VALUE = ""
/* These values were generated using command-line tools (openssl dgst -md5, shasum [-a 1 224 256 384 512 512224 512256], rhash --sha3-224, b2sum -l 224)
* Ex: {@code $ echo -n "" | openssl dgst -md5}
*/
final def EXPECTED_HASHES = [
md2 : "8350e5a3e24c153df2275c9f80692773",
md5 : "d41d8cd98f00b204e9800998ecf8427e",
sha_1 : "da39a3ee5e6b4b0d3255bfef95601890afd80709",
sha_224 : "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f",
sha_256 : "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
sha_384 : "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b",
sha_512 : "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
sha_512_224: "6ed0dd02806fa89e25de060c19d3ac86cabb87d6a0ddd05c333b84f4",
sha_512_256: "c672b8d1ef56ed28ab87c3622c5114069bdd3ad7b8f9737498d0c01ecef0967a",
sha3_224 : "6b4e03423667dbb73b6e15454f0eb1abd4597f9a1b078e3f5b5a6bc7",
sha3_256 : "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
sha3_384 : "0c63a75b845e4f7d01107d852e4c2485c51a50aaaa94fc61995e71bbee983a2ac3713831264adb47fb6bd1e058d5f004",
sha3_512 : "a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a615b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26",
blake2_160 : "3345524abf6bbe1809449224b5972c41790b6cf2",
blake2_256 : "0e5751c026e543b2e8ab2eb06099daa1d1e5df47778f7787faab45cdf12fe3a8",
blake2_384 : "b32811423377f52d7862286ee1a72ee540524380fda1724a6f25d7978c6fd3244a6caf0498812673c5e05ef583825100",
blake2_512 : "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce"
]
// Act
def generatedHashes = algorithms.collectEntries { HashAlgorithm algorithm ->
String hash = HashService.hashValue(algorithm, EMPTY_VALUE, StandardCharsets.UTF_8)
logger.info("${algorithm.getName().padLeft(11)}('${EMPTY_VALUE}') [${hash.length() / 2}] = ${hash}")
[(algorithm.name), hash]
}
// Assert
generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName)
assertEquals(EXPECTED_HASHES[key], hash)
}
}
@Test
void testShouldBuildHashAlgorithmAllowableValues() throws Exception {
// Arrange
final def EXPECTED_ALGORITHMS = HashAlgorithm.values()
logger.info("The consistent list of hash algorithms available [${EXPECTED_ALGORITHMS.size()}]: \n${EXPECTED_ALGORITHMS.collect { "\t${it.name}" }.join("\n")}")
// Act
def allowableValues = HashService.buildHashAlgorithmAllowableValues()
// Assert
assertInstanceOf(AllowableValue[].class, allowableValues)
def valuesList = allowableValues as List<AllowableValue>
assertEquals(EXPECTED_ALGORITHMS.size(), valuesList.size())
EXPECTED_ALGORITHMS.each { HashAlgorithm expectedAlgorithm ->
def matchingValue = valuesList.find { it.value == expectedAlgorithm.name }
assertEquals(expectedAlgorithm.name, matchingValue.displayName)
assertEquals(expectedAlgorithm.buildAllowableValueDescription(), matchingValue.description)
}
}
@Test
void testShouldBuildCharacterSetAllowableValues() throws Exception {
// Arrange
final def EXPECTED_CHARACTER_SETS = [
StandardCharsets.US_ASCII,
StandardCharsets.ISO_8859_1,
StandardCharsets.UTF_8,
StandardCharsets.UTF_16BE,
StandardCharsets.UTF_16LE,
StandardCharsets.UTF_16,
]
logger.info("The consistent list of character sets available [${EXPECTED_CHARACTER_SETS.size()}]: \n${EXPECTED_CHARACTER_SETS.collect { "\t${it.name()}" }.join("\n")}")
def expectedDescriptions =
["UTF-16": "This character set normally decodes using an optional BOM at the beginning of the data but encodes by inserting a BE BOM. For hashing, it will be replaced with UTF-16BE. "]
// Act
def allowableValues = HashService.buildCharacterSetAllowableValues()
// Assert
assertInstanceOf(AllowableValue[].class, allowableValues)
def valuesList = allowableValues as List<AllowableValue>
assertEquals(EXPECTED_CHARACTER_SETS.size(), valuesList.size())
EXPECTED_CHARACTER_SETS.each { Charset charset ->
def matchingValue = valuesList.find { it.value == charset.name() }
assertEquals(charset.name(), matchingValue.displayName)
assertEquals((expectedDescriptions[charset.name()] ?: charset.displayName()), matchingValue.description)
}
}
@Test
void testShouldHashValueFromStream() throws Exception {
// Arrange
// No command-line md2sum tool available
def algorithms = HashAlgorithm.values() - HashAlgorithm.MD2
StringBuilder sb = new StringBuilder()
10_000.times { int i ->
sb.append("${i.toString().padLeft(5)}: ${"apachenifi " * 10}\n")
}
/* These values were generated using command-line tools (openssl dgst -md5, shasum [-a 1 224 256 384 512 512224 512256], rhash --sha3-224, b2sum -l 160)
* Ex: {@code $ openssl dgst -md5 src/test/resources/HashServiceTest/largefile.txt}
*/
final def EXPECTED_HASHES = [
md5 : "8d329076847b678449610a5fb53997d2",
sha_1 : "09cd981ee7529cfd6268a69c0d53e8117e9c78b1",
sha_224 : "4d4d58c226959e0775e627a866eaa26bf18121d578b559946aea6f8c",
sha_256 : "ce50f183a8011a86c5162e94481c6b14ad921a8001746806063b3033e71440eb",
sha_384 : "62a13a410566856422f0b81b2e6ab26f91b3da1a877a5c24f681d2812f26abbc43fb637954879915b3cd9aad626ca71c",
sha_512 : "3f036116c78b1d9e2017bb1fd4b04f449839e6434c94442edebffdcdfbac1d79b483978126f0ffb12824f14ecc36a07dc95f0ba04aa68885456f3f6381471e07",
sha_512_224: "aa7227a80889366a2325801a5cfa67f29c8f272f4284aecfe5daba3c",
sha_512_256: "76faa424ee31bcb1f3a41a848806e288cb064a6bf1867881ee1b439dd8b38e40",
sha3_224 : "d4bb36bf2d00117ade2e63c6fa2ef5f6714d8b6c7a40d12623f95fd0",
sha3_256 : "f93ff4178bc7f466444a822191e152332331ba51eee42b952b3be1b46b1921f7",
sha3_384 : "7e4dfb0073645f059e5837f7c066bffd7f8b5d888b0179a8f0be6bb11c7d631847c468d4d861abcdc96503d91f2a7a78",
sha3_512 : "bf8e83f3590727e04777406e1d478615cf68468ad8690dba3f22a879e08022864a2b4ad8e8a1cbc88737578abd4b2e8493e3bda39a81af3f21fc529c1a7e3b52",
blake2_160 : "71dd4324a1f72aa10aaa59ee4d79ceee8d8915e6",
blake2_256 : "5a25864c69f42adeefc343989babb6972df38da47bb6ce712fbef4474266b539",
blake2_384 : "52417243317ca01693ba835bd5d6655c73a2f70d811b4d26ddacf9e3b74fc3993f30adc64fb6c23a6a5c1e36771a0b95",
blake2_512 : "be81dbc396a9e11c6189d2408a956466fb1c784d2d34495f9ca43434041b425675005deaeea1a04b1f44db0200b19cde5a40fd5e88414bb300620bc3d5e30f6a"
]
// Act
def generatedHashes = algorithms.collectEntries { HashAlgorithm algorithm ->
// Get a new InputStream for each iteration, or it will calculate the hash of an empty input on iterations 1 - n
InputStream input = new ByteArrayInputStream(sb.toString().bytes)
String hash = HashService.hashValueStreaming(algorithm, input)
[(algorithm.name), hash]
}
// Assert
generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName)
assertEquals(EXPECTED_HASHES[key], hash)
}
}
/**
* Returns a {@link String} containing the hex-encoded bytes in the format "0xAB 0xCD ...".
*
* @param data the String to convert
* @param charset the {@link Charset} to use
* @return the formatted string
*/
private static String printHexBytes(String data, Charset charset) {
data.getBytes(charset).collect { "0x${Hex.toHexString([it] as byte[]).toUpperCase()}" }.join(" ")
}
private static String translateStringToMapKey(String string) {
string.toLowerCase().replaceAll(/[-\/]/, '_')
}
}

View File

@ -1,237 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.PBEParameterSpec
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertEquals
class NiFiLegacyCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(NiFiLegacyCipherProviderGroovyTest.class)
private static List<EncryptionMethod> pbeEncryptionMethods = new ArrayList<>()
private static List<EncryptionMethod> limitedStrengthPbeEncryptionMethods = new ArrayList<>()
private static final String PROVIDER_NAME = "BC"
private static final int ITERATION_COUNT = 1000
private static final byte[] SALT_16_BYTES = Hex.decodeHex("aabbccddeeff00112233445566778899".toCharArray())
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
pbeEncryptionMethods = EncryptionMethod.values().findAll { it.algorithm.toUpperCase().startsWith("PBE") }
limitedStrengthPbeEncryptionMethods = pbeEncryptionMethods.findAll { !it.isUnlimitedStrength() }
}
private static Cipher getLegacyCipher(String password, byte[] salt, String algorithm) {
try {
final PBEKeySpec pbeKeySpec = new PBEKeySpec(password.toCharArray())
final SecretKeyFactory factory = SecretKeyFactory.getInstance(algorithm, PROVIDER_NAME)
SecretKey tempKey = factory.generateSecret(pbeKeySpec)
final PBEParameterSpec parameterSpec = new PBEParameterSpec(salt, ITERATION_COUNT)
Cipher cipher = Cipher.getInstance(algorithm, PROVIDER_NAME)
cipher.init(Cipher.ENCRYPT_MODE, tempKey, parameterSpec)
return cipher
} catch (Exception e) {
logger.error("Error generating legacy cipher", e)
throw new RuntimeException(e)
}
return null
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider()
final String PASSWORD = "shortPassword"
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod encryptionMethod : limitedStrengthPbeEncryptionMethods) {
logger.info("Using algorithm: {}", encryptionMethod.getAlgorithm())
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(PASSWORD.length(), encryptionMethod)) {
logger.warn("This test is skipped because the password length exceeds the undocumented limit BouncyCastle imposes on a JVM with limited strength crypto policies")
continue
}
byte[] salt = cipherProvider.generateSalt(encryptionMethod)
logger.info("Generated salt ${Hex.encodeHexString(salt)} (${salt.length})")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt, true)
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider()
final String PASSWORD = "shortPassword"
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod encryptionMethod : pbeEncryptionMethods) {
logger.info("Using algorithm: {}", encryptionMethod.getAlgorithm())
byte[] salt = cipherProvider.generateSalt(encryptionMethod)
logger.info("Generated salt ${Hex.encodeHexString(salt)} (${salt.length})")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt, true)
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherShouldSupportLegacyCode() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider()
final String PASSWORD = "short"
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod encryptionMethod : limitedStrengthPbeEncryptionMethods) {
logger.info("Using algorithm: {}", encryptionMethod.getAlgorithm())
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(PASSWORD.length(), encryptionMethod)) {
logger.warn("This test is skipped because the password length exceeds the undocumented limit BouncyCastle imposes on a JVM with limited strength crypto policies")
continue
}
byte[] salt = cipherProvider.generateSalt(encryptionMethod)
logger.info("Generated salt ${Hex.encodeHexString(salt)} (${salt.length})")
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(PASSWORD, salt, encryptionMethod.getAlgorithm())
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
Cipher providedCipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt, false)
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherWithoutSaltShouldSupportLegacyCode() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider()
final String PASSWORD = "short"
final byte[] SALT = new byte[0]
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
logger.info("Using algorithm: {}", em.getAlgorithm())
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(PASSWORD.length(), em)) {
logger.warn("This test is skipped because the password length exceeds the undocumented limit BouncyCastle imposes on a JVM with limited strength crypto policies")
continue
}
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(PASSWORD, SALT, em.getAlgorithm())
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
Cipher providedCipher = cipherProvider.getCipher(em, PASSWORD, false)
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherShouldIgnoreKeyLength() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider()
final String PASSWORD = "shortPassword"
final byte[] SALT = SALT_16_BYTES
final String plaintext = "This is a plaintext message."
final def KEY_LENGTHS = [-1, 40, 64, 128, 192, 256]
// Initialize a cipher for encryption
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES
final Cipher cipher128 = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, true)
byte[] cipherBytes = cipher128.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
// Act
KEY_LENGTHS.each { int keyLength ->
logger.info("Decrypting with 'requested' key length: ${keyLength}")
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, keyLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
}

View File

@ -1,306 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.PBEParameterSpec
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assertions.fail
class OpenSSLPKCS5CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(OpenSSLPKCS5CipherProviderGroovyTest.class)
private static List<EncryptionMethod> pbeEncryptionMethods = new ArrayList<>()
private static List<EncryptionMethod> limitedStrengthPbeEncryptionMethods = new ArrayList<>()
private static final String PROVIDER_NAME = "BC"
private static final int ITERATION_COUNT = 0
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
pbeEncryptionMethods = EncryptionMethod.values().findAll { it.algorithm.toUpperCase().startsWith("PBE") }
limitedStrengthPbeEncryptionMethods = pbeEncryptionMethods.findAll { !it.isUnlimitedStrength() }
}
private static Cipher getLegacyCipher(String password, byte[] salt, String algorithm) {
try {
final PBEKeySpec pbeKeySpec = new PBEKeySpec(password.toCharArray())
final SecretKeyFactory factory = SecretKeyFactory.getInstance(algorithm, PROVIDER_NAME)
SecretKey tempKey = factory.generateSecret(pbeKeySpec)
final PBEParameterSpec parameterSpec = new PBEParameterSpec(salt, ITERATION_COUNT)
Cipher cipher = Cipher.getInstance(algorithm, PROVIDER_NAME)
cipher.init(Cipher.ENCRYPT_MODE, tempKey, parameterSpec)
return cipher
} catch (Exception e) {
logger.error("Error generating legacy cipher", e)
fail(e.getMessage())
}
return null
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "short"
final byte[] SALT = Hex.decodeHex("aabbccddeeff0011".toCharArray())
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
logger.info("Using algorithm: {}", em.getAlgorithm())
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(PASSWORD.length(), em)) {
logger.warn("This test is skipped because the password length exceeds the undocumented limit BouncyCastle imposes on a JVM with limited strength crypto policies")
continue
}
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, true)
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex("aabbccddeeff0011".toCharArray())
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod em : pbeEncryptionMethods) {
logger.info("Using algorithm: {}", em.getAlgorithm())
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, true)
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherShouldSupportLegacyCode() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex("0011223344556677".toCharArray())
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
logger.info("Using algorithm: {}", em.getAlgorithm())
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(PASSWORD.length(), em)) {
logger.warn("This test is skipped because the password length exceeds the undocumented limit BouncyCastle imposes on a JVM with limited strength crypto policies")
continue
}
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(PASSWORD, SALT, em.getAlgorithm())
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
Cipher providedCipher = cipherProvider.getCipher(em, PASSWORD, SALT, false)
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherWithoutSaltShouldSupportLegacyCode() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "short"
final byte[] SALT = new byte[0]
final String plaintext = "This is a plaintext message."
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
logger.info("Using algorithm: {}", em.getAlgorithm())
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(PASSWORD.length(), em)) {
logger.warn("This test is skipped because the password length exceeds the undocumented limit BouncyCastle imposes on a JVM with limited strength crypto policies")
continue
}
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(PASSWORD, SALT, em.getAlgorithm())
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
Cipher providedCipher = cipherProvider.getCipher(em, PASSWORD, false)
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherShouldIgnoreKeyLength() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex("aabbccddeeff0011".toCharArray())
final String plaintext = "This is a plaintext message."
final def KEY_LENGTHS = [-1, 40, 64, 128, 192, 256]
// Initialize a cipher for encryption
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES
final Cipher cipher128 = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, true)
byte[] cipherBytes = cipher128.doFinal(plaintext.getBytes("UTF-8"))
logger.info("Cipher text: {} {}", Hex.encodeHexString(cipherBytes), cipherBytes.length)
// Act
KEY_LENGTHS.each { int keyLength ->
logger.info("Decrypting with 'requested' key length: ${keyLength}")
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, keyLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assertEquals(plaintext, recovered)
}
}
@Test
void testGetCipherShouldRequireEncryptionMethod() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex("0011223344556677".toCharArray())
// Act
logger.info("Using algorithm: null")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(null, PASSWORD, SALT, false))
// Assert
assertTrue(iae.getMessage().contains("The encryption method must be specified"))
}
@Test
void testGetCipherShouldRequirePassword() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final byte[] SALT = Hex.decodeHex("0011223344556677".toCharArray())
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES
// Act
logger.info("Using algorithm: ${encryptionMethod}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, "", SALT, false))
// Assert
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
}
@Test
void testGetCipherShouldValidateSaltLength() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex("00112233445566".toCharArray())
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES
// Act
logger.info("Using algorithm: ${encryptionMethod}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, false))
// Assert
assertTrue(iae.getMessage().contains("Salt must be 8 bytes US-ASCII encoded"))
}
@Test
void testGenerateSaltShouldProvideValidSalt() throws Exception {
// Arrange
PBECipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider()
// Act
byte[] salt = cipherProvider.generateSalt()
logger.info("Checking salt ${Hex.encodeHexString(salt)}")
// Assert
assertEquals(cipherProvider.getDefaultSaltLength(), salt.length)
byte [] notExpected = new byte [cipherProvider.defaultSaltLength]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, salt))
}
}

View File

@ -1,543 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class PBKDF2CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(PBKDF2CipherProviderGroovyTest.class)
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess"
private static List<EncryptionMethod> strongKDFEncryptionMethods
public static final String MICROBENCHMARK = "microbenchmark"
private static final int DEFAULT_KEY_LENGTH = 128
private static final int TEST_ITERATION_COUNT = 1000
private final String DEFAULT_PRF = "SHA-512"
private final String SALT_HEX = "0123456789ABCDEFFEDCBA9876543210"
private final String IV_HEX = "01" * 16
private static ArrayList<Integer> AES_KEY_LENGTHS
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
strongKDFEncryptionMethods = EncryptionMethod.values().findAll { it.isCompatibleWithStrongKDFs() }
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
AES_KEY_LENGTHS = [128, 192, 256]
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldRejectInvalidIV() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final def INVALID_IVS = (0..15).collect { int length -> new byte[length] }
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
INVALID_IVS.each { byte[] badIV ->
logger.info("IV: ${Hex.encodeHexString(badIV)} ${badIV.length}")
// Encrypt should print a warning about the bad IV but overwrite it
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true)
// Decrypt should fail
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false))
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final int LONG_KEY_LENGTH = 256
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, LONG_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, LONG_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testShouldRejectEmptyPRF() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
String prf = ""
// Act
logger.info("Using PRF ${prf}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT))
// Assert
assertTrue(iae.getMessage().contains("Cannot resolve empty PRF"))
}
@Test
void testShouldResolveDefaultPRF() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
final PBKDF2CipherProvider SHA512_PROVIDER = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
String prf = "sha768"
logger.info("Using ${prf}")
// Act
cipherProvider = new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT)
logger.info("Resolved PRF to ${cipherProvider.getPRFName()}")
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = SHA512_PROVIDER.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
@Test
void testShouldResolveVariousPRFs() throws Exception {
// Arrange
final List<String> PRFS = ["SHA-1", "MD5", "SHA-256", "SHA-384", "SHA-512"]
RandomIVPBECipherProvider cipherProvider
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
PRFS.each { String prf ->
logger.info("Using ${prf}")
cipherProvider = new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT)
logger.info("Resolved PRF to ${cipherProvider.getPRFName()}")
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldSupportExternalCompatibility() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider("SHA-256", TEST_ITERATION_COUNT)
final String PLAINTEXT = "This is a plaintext message."
final String PASSWORD = "thisIsABadPassword"
// These values can be generated by running `$ ./openssl_pbkdf2.rb` in the terminal
final byte[] SALT = Hex.decodeHex("ae2481bee3d8b5d5b732bf464ea2ff01" as char[])
final byte[] IV = Hex.decodeHex("26db997dcd18472efd74dabe5ff36853" as char[])
final String CIPHER_TEXT = "92edbabae06add6275a1d64815755a9ba52afc96e2c1a316d3abbe1826e96f6c"
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT as char[])
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
@Test
void testGetCipherShouldHandleDifferentPRFs() throws Exception {
// Arrange
RandomIVPBECipherProvider sha256CP = new PBKDF2CipherProvider("SHA-256", TEST_ITERATION_COUNT)
RandomIVPBECipherProvider sha512CP = new PBKDF2CipherProvider("SHA-512", TEST_ITERATION_COUNT)
final String PASSWORD = "thisIsABadPassword"
final byte[] SALT = [0x11] * 16
final byte[] IV = [0x22] * 16
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
Cipher sha256Cipher = sha256CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
byte[] sha256CipherBytes = sha256Cipher.doFinal(PLAINTEXT.bytes)
Cipher sha512Cipher = sha512CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
byte[] sha512CipherBytes = sha512Cipher.doFinal(PLAINTEXT.bytes)
// Assert
assertFalse(Arrays.equals(sha512CipherBytes, sha256CipherBytes))
Cipher sha256DecryptCipher = sha256CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] sha256RecoveredBytes = sha256DecryptCipher.doFinal(sha256CipherBytes)
assertArrayEquals(PLAINTEXT.bytes, sha256RecoveredBytes)
Cipher sha512DecryptCipher = sha512CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] sha512RecoveredBytes = sha512DecryptCipher.doFinal(sha512CipherBytes)
assertArrayEquals(PLAINTEXT.bytes, sha512RecoveredBytes)
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
// Assert
assertTrue(iae.getMessage().contains( "Cannot decrypt without a valid IV"))
}
}
@Test
void testGetCipherShouldRejectInvalidSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "thisIsABadPassword"
final def INVALID_SALTS = ['pbkdf2', '$3a$11$', 'x', '$2a$10$', '', null]
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt?.bytes, DEFAULT_KEY_LENGTH, true))
// Assert
assertTrue(iae.getMessage().contains("The salt must be at least 16 bytes. To generate a salt, use PBKDF2CipherProvider#generateSalt"))
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
// Currently only AES ciphers are compatible with PBKDF2, so redundant to test all algorithms
final def VALID_KEY_LENGTHS = AES_KEY_LENGTHS
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
VALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
final String PASSWORD = "shortPassword"
final byte[] SALT = Hex.decodeHex(SALT_HEX as char[])
final byte[] IV = Hex.decodeHex(IV_HEX as char[])
// Currently only AES ciphers are compatible with PBKDF2, so redundant to test all algorithms
final def VALID_KEY_LENGTHS = [-1, 40, 64, 112, 512]
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
VALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true")
@Test
void testDefaultConstructorShouldProvideStrongIterationCount() {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider()
// Values taken from http://wildlyinaccurate.com/bcrypt-choosing-a-work-factor/ and http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
// Calculate the iteration count to reach 500 ms
int minimumIterationCount = calculateMinimumIterationCount()
logger.info("Determined minimum safe iteration count to be ${minimumIterationCount}")
// Act
int iterationCount = cipherProvider.getIterationCount()
logger.info("Default iteration count ${iterationCount}")
// Assert
assertTrue("The default iteration count for PBKDF2CipherProvider is too weak. Please update the default value to a stronger level.", iterationCount >= minimumIterationCount)
}
/**
* Returns the iteration count required for a derivation to exceed 500 ms on this machine using the default PRF.
* Code adapted from http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
*
* @return the minimum iteration count
*/
private static int calculateMinimumIterationCount() {
// High start-up cost, so run multiple times for better benchmarking
final int RUNS = 10
// Benchmark using an iteration count of 10k
int iterationCount = 10_000
final byte[] SALT = [0x00 as byte] * 16
final byte[] IV = [0x01 as byte] * 16
String defaultPrf = new PBKDF2CipherProvider().getPRFName()
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(defaultPrf, iterationCount)
// Run once to prime the system
double duration = time {
Cipher cipher = cipherProvider.getCipher(EncryptionMethod.AES_CBC, MICROBENCHMARK, SALT, IV, DEFAULT_KEY_LENGTH, false)
}
logger.info("First run of iteration count ${iterationCount} took ${duration} ms (ignored)")
def durations = []
RUNS.times { int i ->
duration = time {
// Use encrypt mode with provided salt and IV to minimize overhead during benchmark call
Cipher cipher = cipherProvider.getCipher(EncryptionMethod.AES_CBC, "${MICROBENCHMARK}${i}", SALT, IV, DEFAULT_KEY_LENGTH, false)
}
logger.info("Iteration count ${iterationCount} took ${duration} ms")
durations << duration
}
duration = durations.sum() / durations.size()
logger.info("Iteration count ${iterationCount} averaged ${duration} ms")
// Keep increasing iteration count until the estimated duration is over 500 ms
while (duration < 500) {
iterationCount *= 2
duration *= 2
}
logger.info("Returning iteration count ${iterationCount} for ${duration} ms")
return iterationCount
}
private static double time(Closure c) {
long start = System.nanoTime()
c.call()
long end = System.nanoTime()
return (end - start) / 1_000_000.0
}
@Test
void testGenerateSaltShouldProvideValidSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT)
// Act
byte[] salt = cipherProvider.generateSalt()
logger.info("Checking salt ${Hex.encodeHexString(salt)}")
// Assert
assertEquals(16,salt.length )
byte [] notExpected = new byte[16]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, salt))
}
}

View File

@ -1,679 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.commons.codec.binary.Base64
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.apache.nifi.security.util.crypto.scrypt.Scrypt
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import java.security.SecureRandom
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class ScryptCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(ScryptCipherProviderGroovyTest.class)
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess"
private static List<EncryptionMethod> strongKDFEncryptionMethods
private static final int DEFAULT_KEY_LENGTH = 128
public static final String MICROBENCHMARK = "microbenchmark"
private static ArrayList<Integer> AES_KEY_LENGTHS
RandomIVPBECipherProvider cipherProvider
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
strongKDFEncryptionMethods = EncryptionMethod.values().findAll { it.isCompatibleWithStrongKDFs() }
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
AES_KEY_LENGTHS = [128, 192, 256]
}
@BeforeEach
void setUp() throws Exception {
// Very fast parameters to test for correctness rather than production values
cipherProvider = new ScryptCipherProvider(4, 1, 1)
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("01" * 16 as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final int LONG_KEY_LENGTH = 256
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, LONG_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, iv, LONG_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testScryptShouldSupportExternalCompatibility() throws Exception {
// Arrange
// Default values are N=2^14, r=8, p=1, but the provided salt will contain the parameters used
cipherProvider = new ScryptCipherProvider()
final String PLAINTEXT = "This is a plaintext message."
final String PASSWORD = "thisIsABadPassword"
final int DK_LEN = 128
// These values can be generated by running `$ ./openssl_scrypt.rb` in the terminal
final byte[] SALT = Hex.decodeHex("f5b8056ea6e66edb8d013ac432aba24a" as char[])
logger.info("Expected salt: ${Hex.encodeHexString(SALT)}")
final byte[] IV = Hex.decodeHex("76a00f00878b8c3db314ae67804c00a1" as char[])
final String CIPHER_TEXT = "604188bf8e9137bc1b24a0ab01973024bc5935e9ae5fedf617bdca028c63c261"
logger.sanity("Ruby cipher text: ${CIPHER_TEXT}")
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT as char[])
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Sanity check
String rubyKeyHex = "a8efbc0a709d3f89b6bb35b05fc8edf5"
logger.sanity("Using key: ${rubyKeyHex}")
logger.sanity("Using IV: ${Hex.encodeHexString(IV)}")
Cipher rubyCipher = Cipher.getInstance(encryptionMethod.algorithm, "BC")
def rubyKey = new SecretKeySpec(Hex.decodeHex(rubyKeyHex as char[]), "AES")
def ivSpec = new IvParameterSpec(IV)
rubyCipher.init(Cipher.ENCRYPT_MODE, rubyKey, ivSpec)
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
logger.sanity("Decrypted generated cipher text successfully")
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text successfully")
// n$r$p$hex_salt_SL$hex_hash_HL
final String FULL_HASH = "400\$8\$24\$f5b8056ea6e66edb8d013ac432aba24a\$a8efbc0a709d3f89b6bb35b05fc8edf5"
logger.info("Full Hash: ${FULL_HASH}")
def (String nStr, String rStr, String pStr, String saltHex, String hashHex) = FULL_HASH.split("\\\$")
def (n, r, p) = [nStr, rStr, pStr].collect { Integer.valueOf(it, 16) }
logger.info("N: Hex ${nStr} -> ${n}")
logger.info("r: Hex ${rStr} -> ${r}")
logger.info("p: Hex ${pStr} -> ${p}")
logger.info("Salt: ${saltHex}")
logger.info("Hash: ${hashHex}")
// Form Java-style salt with cost params from Ruby-style
String javaSalt = Scrypt.formatSalt(Hex.decodeHex(saltHex as char[]), n, r, p)
logger.info("Formed Java-style salt: ${javaSalt}")
// Convert hash from hex to Base64
String base64Hash = CipherUtility.encodeBase64NoPadding(Hex.decodeHex(hashHex as char[]))
logger.info("Converted hash from hex ${hashHex} to Base64 ${base64Hash}")
assertEquals(hashHex, Hex.encodeHexString(Base64.decodeBase64(base64Hash)))
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}")
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, javaSalt.bytes, IV, DK_LEN, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
@Test
void testGetCipherShouldHandleSaltWithoutParameters() throws Exception {
// Arrange
// To help Groovy resolve implementation private methods not known at interface level
cipherProvider = cipherProvider as ScryptCipherProvider
final String PASSWORD = "shortPassword"
final byte[] SALT = new byte[cipherProvider.defaultSaltLength]
new SecureRandom().nextBytes(SALT)
// final byte[] SALT = [0x00] * 16 as byte[]
final String EXPECTED_FORMATTED_SALT = cipherProvider.formatSaltForScrypt(SALT)
logger.info("Expected salt: ${EXPECTED_FORMATTED_SALT}")
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
// Manually initialize a cipher for decrypt with the expected salt
byte[] parsedSalt = new byte[cipherProvider.defaultSaltLength]
def params = []
cipherProvider.parseSalt(EXPECTED_FORMATTED_SALT, parsedSalt, params)
def (int n, int r, int p) = params
byte[] keyBytes = Scrypt.deriveScryptKey(PASSWORD.bytes, parsedSalt, n, r, p, DEFAULT_KEY_LENGTH)
logger.info("Manually derived key bytes: ${Hex.encodeHexString(keyBytes)}")
SecretKey key = new SecretKeySpec(keyBytes, "AES")
Cipher manualCipher = Cipher.getInstance(encryptionMethod.algorithm, encryptionMethod.provider)
manualCipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv))
byte[] recoveredBytes = manualCipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
@Test
void testGetCipherShouldNotAcceptInvalidSalts() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
final def INVALID_SALTS = ['bad_sal', '$3a$11$', 'x', '$2a$10$']
final LENGTH_MESSAGE = "The raw salt must be greater than or equal to 8 bytes"
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains(LENGTH_MESSAGE))
}
}
@Test
void testGetCipherShouldHandleUnformattedSalts() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
final def RECOVERABLE_SALTS = ['$ab$00$acbdefghijklmnopqrstuv', '$4$1$1$0123456789abcdef', '$400$1$1$abcdefghijklmnopqrstuv']
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
RECOVERABLE_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
// Assert
assertNotNull(cipher)
}
}
@Test
void testGetCipherShouldRejectEmptySalt() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use ScryptCipherProvider#generateSalt"))
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("00" * 16 as char[])
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("01" * 16 as char[])
final def VALID_KEY_LENGTHS = AES_KEY_LENGTHS
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
VALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
logger.info("IV: ${Hex.encodeHexString(IV)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
// Assert
assertEquals(PLAINTEXT, recovered)
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
final String PASSWORD = "shortPassword"
final byte[] SALT = cipherProvider.generateSalt()
final byte[] IV = Hex.decodeHex("00" * 16 as char[])
// Even though Scrypt can derive keys of arbitrary length, it will fail to validate if the underlying cipher does not support it
final def INVALID_KEY_LENGTHS = [-1, 40, 64, 112, 512]
// Currently only AES ciphers are compatible with Scrypt, so redundant to test all algorithms
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
INVALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@Test
void testScryptShouldNotAcceptInvalidPassword() {
// Arrange
String badPassword = ""
byte[] salt = [0x01 as byte] * 16
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
()-> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true))
// Assert
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
}
@Test
void testGenerateSaltShouldUseProvidedParameters() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new ScryptCipherProvider(8, 2, 2)
int n = cipherProvider.getN()
int r = cipherProvider.getR()
int p = cipherProvider.getP()
// Act
final String salt = new String(cipherProvider.generateSalt())
logger.info("Salt: ${salt}")
// Assert
assertTrue((salt =~ "^(?i)\\\$s0\\\$[a-f0-9]{5,16}\\\$").find())
String params = Scrypt.encodeParams(n, r, p)
assertTrue(salt.contains("\$${params}\$"))
}
@Test
void testShouldParseSalt() throws Exception {
// Arrange
cipherProvider = cipherProvider as ScryptCipherProvider
final byte[] EXPECTED_RAW_SALT = Hex.decodeHex("f5b8056ea6e66edb8d013ac432aba24a" as char[])
final int EXPECTED_N = 1024
final int EXPECTED_R = 8
final int EXPECTED_P = 36
final String FORMATTED_SALT = "\$s0\$a0824\$9bgFbqbmbtuNATrEMquiSg"
logger.info("Using salt: ${FORMATTED_SALT}")
byte[] rawSalt = new byte[16]
def params = []
// Act
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params)
// Assert
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
assertEquals(EXPECTED_N, params[0])
assertEquals(EXPECTED_R, params[1])
assertEquals(EXPECTED_P, params[2])
}
@Test
void testShouldVerifyPBoundary() throws Exception {
// Arrange
final int r = 8
final int p = 1
// Act
boolean valid = ScryptCipherProvider.isPValid(r, p)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailPBoundary() throws Exception {
// Arrange
// The p upper bound is calculated with the formula below, when r = 8:
// pBoundary = ((Math.pow(2,32))-1) * (32.0/(r * 128)), where pBoundary = 134217727.96875;
Map<Integer, Integer> costParameters = [8:134217729, 128:8388608, 4096: 0]
// Act and Assert
costParameters.entrySet().forEach(entry -> {
assertFalse(ScryptCipherProvider.isPValid(entry.getKey(), entry.getValue()))
})
}
@Test
void testShouldVerifyRValue() throws Exception {
// Arrange
final int r = 8
// Act
boolean valid = ScryptCipherProvider.isRValid(r)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailRValue() throws Exception {
// Arrange
final int r = 0
// Act
boolean valid = ScryptCipherProvider.isRValid(r)
// Assert
assertFalse(valid)
}
@Test
void testShouldValidateScryptCipherProviderPBoundary() throws Exception {
// Arrange
final int n = 64
final int r = 8
final int p = 1
// Act
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p)
// Assert
assertNotNull(testCipherProvider)
}
@Test
void testShouldCatchInvalidP() throws Exception {
// Arrange
final int n = 64
final int r = 8
final int p = 0
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new ScryptCipherProvider(n, r, p))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("Invalid p value exceeds p boundary"))
}
@Test
void testShouldCatchInvalidR() throws Exception {
// Arrange
final int n = 64
final int r = 0
final int p = 0
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new ScryptCipherProvider(n, r, p))
logger.warn(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("Invalid r value; must be greater than 0"))
}
@Test
void testShouldAcceptFormattedSaltWithPlus() throws Exception {
// Arrange
final String FULL_SALT_WITH_PLUS = "\$s0\$e0801\$smJD8vwWI3+uQCHYz2yg0+"
// Act
boolean isScryptSalt = ScryptCipherProvider.isScryptFormattedSalt(FULL_SALT_WITH_PLUS)
logger.info("Is Scrypt salt: ${isScryptSalt}")
// Assert
assertTrue(isScryptSalt)
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true",
disabledReason = "This test can be run on a specific machine to evaluate if the default parameters are sufficient")
@Test
void testDefaultConstructorShouldProvideStrongParameters() {
// Arrange
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider()
/** See this Stack Overflow answer for a good visualization of the interplay between N, r, p <a href="http://stackoverflow.com/a/30308723" rel="noopener">http://stackoverflow.com/a/30308723</a> */
// Act
int n = testCipherProvider.getN()
int r = testCipherProvider.getR()
int p = testCipherProvider.getP()
logger.info("Default parameters N=${n}, r=${r}, p=${p}")
// Calculate the parameters to reach 500 ms
def (int minimumN, int minimumR, int minimumP) = calculateMinimumParameters(r, p)
logger.info("Determined minimum safe parameters to be N=${minimumN}, r=${minimumR}, p=${minimumP}")
// Assert
assertTrue(n >= minimumN, "The default parameters for ScryptCipherProvider are too weak. Please update the default values to a stronger level.")
}
/**
* Returns the parameters required for a derivation to exceed 500 ms on this machine. Code adapted from http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
*
* @param r the block size in bytes (defaults to 8)
* @param p the parallelization factor (defaults to 1)
* @param maxHeapSize the maximum heap size to use in bytes (defaults to 1 GB)
*
* @return the minimum scrypt parameters as [N, r, p]
*/
private static List<Integer> calculateMinimumParameters(int r = 8, int p = 1, int maxHeapSize = 1024 * 1024 * 1024) {
// High start-up cost, so run multiple times for better benchmarking
final int RUNS = 10
// Benchmark using N=2^4
int n = 2**4
int dkLen = 128
assertTrue(Scrypt.calculateExpectedMemory(n, r, p) <= maxHeapSize)
byte[] salt = new byte[Scrypt.defaultSaltLength]
new SecureRandom().nextBytes(salt)
// Run once to prime the system
double duration = time {
Scrypt.scrypt(MICROBENCHMARK, salt, n, r, p, dkLen)
}
logger.info("First run of N=${n}, r=${r}, p=${p} took ${duration} ms (ignored)")
def durations = []
RUNS.times { int i ->
duration = time {
Scrypt.scrypt(MICROBENCHMARK, salt, n, r, p, dkLen)
}
logger.info("N=${n}, r=${r}, p=${p} took ${duration} ms")
durations << duration
}
duration = durations.sum() / durations.size()
logger.info("N=${n}, r=${r}, p=${p} averaged ${duration} ms")
// Doubling N would double the run time
// Keep increasing N until the estimated duration is over 500 ms
while (duration < 500) {
n *= 2
duration *= 2
}
logger.info("Returning N=${n}, r=${r}, p=${p} for ${duration} ms")
return [n, r, p]
}
private static double time(Closure c) {
long start = System.nanoTime()
c.call()
long end = System.nanoTime()
return (end - start) / 1_000_000.0
}
}

View File

@ -1,377 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.bouncycastle.util.encoders.Hex
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class ScryptSecureHasherTest {
@Test
void testShouldBeDeterministicWithStaticSalt() {
// Arrange
int n = 1024
int r = 8
int p = 2
int dkLength = 32
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "a67fd2f4b3aa577b8ecdb682e60b4451a84611dcbbc534bce17616056ef8965d"
ScryptSecureHasher scryptSH = new ScryptSecureHasher(n, r, p, dkLength)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = scryptSH.hashRaw(inputBytes)
String hashHex = new String(Hex.encode(hash))
results << hashHex
}
// Assert
results.forEach( result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
void testShouldBeDifferentWithRandomSalt() {
// Arrange
int n = 1024
int r = 8
int p = 2
int dkLength = 128
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "a67fd2f4b3aa577b8ecdb682e60b4451"
ScryptSecureHasher scryptSH = new ScryptSecureHasher(n, r, p, dkLength, 16)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = scryptSH.hashRaw(inputBytes)
String hashHex = new String(Hex.encode(hash))
results << hashHex
}
// Assert
assertTrue(results.unique().size() == results.size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
void testShouldHandleArbitrarySalt() {
// Arrange
int n = 1024
int r = 8
int p = 2
int dkLength = 32
def input = "This is a sensitive value"
byte[] inputBytes = input.bytes
final String EXPECTED_HASH_HEX = "a67fd2f4b3aa577b8ecdb682e60b4451a84611dcbbc534bce17616056ef8965d"
final String EXPECTED_HASH_BASE64 = "pn/S9LOqV3uOzbaC5gtEUahGEdy7xTS84XYWBW74ll0"
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX)
// Static salt instance
ScryptSecureHasher staticSaltHasher = new ScryptSecureHasher(n, r, p, dkLength)
ScryptSecureHasher arbitrarySaltHasher = new ScryptSecureHasher(n, r, p, dkLength, 16)
final byte[] STATIC_SALT = AbstractSecureHasher.STATIC_SALT
final String DIFFERENT_STATIC_SALT = "Diff Static Salt"
// Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes)
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT)
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8))
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes)
String staticSaltHashHex = staticSaltHasher.hashHex(input)
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8))
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT)
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input)
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input)
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8))
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT)
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
void testShouldValidateArbitrarySalt() {
// Arrange
int n = 1024
int r = 8
int p = 2
int dkLength = 32
def input = "This is a sensitive value"
byte[] inputBytes = input.bytes
// Static salt instance
ScryptSecureHasher secureHasher = new ScryptSecureHasher(n, r, p, dkLength, 16)
final byte[] STATIC_SALT = "bad_sal".bytes
assertThrows(IllegalArgumentException.class, { -> new ScryptSecureHasher(n, r, p, dkLength, 7) })
assertThrows(RuntimeException.class, { -> secureHasher.hashRaw(inputBytes, STATIC_SALT) })
assertThrows(RuntimeException.class, { -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) })
assertThrows(RuntimeException.class, { -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) })
}
@Test
void testShouldFormatHex() {
// Arrange
String input = "This is a sensitive value"
final String EXPECTED_HASH_HEX = "6a9c827815fe0718af5e336811fc78dd719c8d9505e015283239b9bf1d24ee71"
SecureHasher scryptSH = new ScryptSecureHasher()
// Act
String hashHex = scryptSH.hashHex(input)
// Assert
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
void testShouldFormatBase64() {
// Arrange
String input = "This is a sensitive value"
final String EXPECTED_HASH_BASE64 = "apyCeBX+BxivXjNoEfx43XGcjZUF4BUoMjm5vx0k7nE"
SecureHasher scryptSH = new ScryptSecureHasher()
// Act
String hashB64 = scryptSH.hashBase64(input)
// Assert
assertEquals(EXPECTED_HASH_BASE64, hashB64)
}
@Test
void testShouldHandleNullInput() {
// Arrange
List<String> inputs = [null, ""]
final String EXPECTED_HASH_HEX = ""
final String EXPECTED_HASH_BASE64 = ""
ScryptSecureHasher scryptSH = new ScryptSecureHasher()
def hexResults = []
def B64Results = []
// Act
inputs.each { String input ->
String hashHex = scryptSH.hashHex(input)
hexResults << hashHex
String hashB64 = scryptSH.hashBase64(input)
B64Results << hashB64
}
// Assert
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
}
/**
* This test can have the minimum time threshold updated to determine if the performance
* is still sufficient compared to the existing threat model.
*/
@EnabledIfSystemProperty(named = "nifi.test.performance", matches = "true")
@Test
void testDefaultCostParamsShouldBeSufficient() {
// Arrange
int testIterations = 100
byte[] inputBytes = "This is a sensitive value".bytes
ScryptSecureHasher scryptSH = new ScryptSecureHasher()
def results = []
def resultDurations = []
// Act
testIterations.times { int i ->
long startNanos = System.nanoTime()
byte[] hash = scryptSH.hashRaw(inputBytes)
long endNanos = System.nanoTime()
long durationNanos = endNanos - startNanos
String hashHex = Hex.encode(hash)
results << hashHex
resultDurations << durationNanos
}
// Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
void testShouldVerifyRBoundary() throws Exception {
// Arrange
final int r = 32
// Act
boolean valid = ScryptSecureHasher.isRValid(r)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailRBoundary() throws Exception {
// Arrange
List<Integer> rValues = [-8, 0, 2147483647]
// Act and Assert
rValues.forEach(rValue -> assertFalse(ScryptSecureHasher.isRValid(rValue)))
}
@Test
void testShouldVerifyNBoundary() throws Exception {
// Arrange
final Integer n = 16385
final int r = 8
// Act and Assert
assertTrue(ScryptSecureHasher.isNValid(n, r))
}
@Test
void testShouldFailNBoundary() throws Exception {
// Arrange
Map<Integer, Integer> costParameters = [(-8): 8, 0: 32]
//Act and Assert
costParameters.entrySet().forEach(entry -> {
assertFalse(ScryptSecureHasher.isNValid(entry.getKey(), entry.getValue()))
})
}
@Test
void testShouldVerifyPBoundary() throws Exception {
// Arrange
final List<Integer> ps = [1, 8, 1024]
final List<Integer> rs = [8, 1024, 4096]
// Act and Assert
ps.forEach(p -> {
rs.forEach(r -> {
assertTrue(ScryptSecureHasher.isPValid(p, r))
})
})
}
@Test
void testShouldFailIfPBoundaryExceeded() throws Exception {
// Arrange
final List<Integer> ps = [4096 * 64, 1024 * 1024]
final List<Integer> rs = [4096, 1024 * 1024]
// Act and Assert
ps.forEach(p -> {
rs.forEach(r -> {
assertFalse(ScryptSecureHasher.isPValid(p, r))
})
})
}
@Test
void testShouldVerifyDKLengthBoundary() throws Exception {
// Arrange
final Integer dkLength = 64
// Act
boolean valid = ScryptSecureHasher.isDKLengthValid(dkLength)
// Assert
assertTrue(valid)
}
@Test
void testShouldFailDKLengthBoundary() throws Exception {
// Arrange
def dKLengths = [-8, 0, 2147483647]
// Act and Assert
dKLengths.forEach( dKLength -> {
assertFalse(ScryptSecureHasher.isDKLengthValid(dKLength))
})
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [0, 64]
// Act and Assert
ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher()
saltLengths.forEach(saltLength -> {
assertTrue(scryptSecureHasher.isSaltLengthValid(saltLength))
})
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [-8, 1, 2147483647]
// Act and Assert
ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher()
saltLengths.forEach(saltLength -> {
assertFalse(scryptSecureHasher.isSaltLengthValid(saltLength))
})
}
}

View File

@ -1,437 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.scrypt
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.crypto.scrypt.Scrypt
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.security.SecureRandom
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assumptions.assumeTrue
class ScryptGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(ScryptGroovyTest.class)
private static final String PASSWORD = "shortPassword"
private static final String SALT_HEX = "0123456789ABCDEFFEDCBA9876543210"
private static final byte[] SALT_BYTES = Hex.decodeHex(SALT_HEX as char[])
// Small values to test for correctness, not timing
private static final int N = 2**4
private static final int R = 1
private static final int P = 1
private static final int DK_LEN = 128
private static final long TWO_GIGABYTES = 2048L * 1024 * 1024
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testDeriveScryptKeyShouldBeInternallyConsistent() throws Exception {
// Arrange
def allKeys = []
final int RUNS = 10
logger.info("Running with '${PASSWORD}', '${SALT_HEX}', $N, $R, $P, $DK_LEN")
// Act
RUNS.times {
byte[] keyBytes = Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, R, P, DK_LEN)
logger.info("Derived key: ${Hex.encodeHexString(keyBytes)}")
allKeys << keyBytes
}
// Assert
assertEquals(RUNS, allKeys.size())
allKeys.forEach(key -> assertArrayEquals(allKeys.first(), key))
}
/**
* This test ensures that the local implementation of Scrypt is compatible with the reference implementation from the Colin Percival paper.
*/
@Test
void testDeriveScryptKeyShouldMatchTestVectors() {
// Arrange
// These values are taken from Colin Percival's scrypt paper: https://www.tarsnap.com/scrypt/scrypt.pdf
final byte[] HASH_2 = Hex.decodeHex("fdbabe1c9d3472007856e7190d01e9fe" +
"7c6ad7cbc8237830e77376634b373162" +
"2eaf30d92e22a3886ff109279d9830da" +
"c727afb94a83ee6d8360cbdfa2cc0640" as char[])
final byte[] HASH_3 = Hex.decodeHex("7023bdcb3afd7348461c06cd81fd38eb" +
"fda8fbba904f8e3ea9b543f6545da1f2" +
"d5432955613f0fcf62d49705242a9af9" +
"e61e85dc0d651e40dfcf017b45575887" as char[])
final def TEST_VECTORS = [
// Empty password is not supported by JCE
[password: "password",
salt : "NaCl",
n : 1024,
r : 8,
p : 16,
dkLen : 64 * 8,
hash : HASH_2],
[password: "pleaseletmein",
salt : "SodiumChloride",
n : 16384,
r : 8,
p : 1,
dkLen : 64 * 8,
hash : HASH_3],
]
// Act
TEST_VECTORS.each { Map params ->
logger.info("Running with '${params.password}', '${params.salt}', ${params.n}, ${params.r}, ${params.p}, ${params.dkLen}")
long memoryInBytes = Scrypt.calculateExpectedMemory(params.n, params.r, params.p)
logger.info("Expected memory usage: (128 * r * N + 128 * r * p) ${memoryInBytes} bytes")
logger.info(" Expected ${Hex.encodeHexString(params.hash)}")
byte[] calculatedHash = Scrypt.deriveScryptKey(params.password.bytes, params.salt.bytes, params.n, params.r, params.p, params.dkLen)
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert
assertArrayEquals(params.hash, calculatedHash)
}
}
/**
* This test ensures that the local implementation of Scrypt is compatible with the reference implementation from the Colin Percival paper. The test vector requires ~1GB {@code byte[]}
* and therefore the Java heap must be at least 1GB. Because {@link nifi/pom.xml} has a {@code surefire} rule which appends {@code -Xmx1G}
* to the Java options, this overrides any IDE options. To ensure the heap is properly set, using the {@code groovyUnitTest} profile will re-append {@code -Xmx3072m} to the Java options.
*/
@Test
void testDeriveScryptKeyShouldMatchExpensiveTestVector() {
// Arrange
long totalMemory = Runtime.getRuntime().totalMemory()
logger.info("Required memory: ${TWO_GIGABYTES} bytes")
logger.info("Max heap memory: ${totalMemory} bytes")
assumeTrue(totalMemory >= TWO_GIGABYTES, "Test is being skipped due to JVM heap size. Please run with -Xmx3072m to set sufficient heap size")
// These values are taken from Colin Percival's scrypt paper: https://www.tarsnap.com/scrypt/scrypt.pdf
final byte[] HASH = Hex.decodeHex("2101cb9b6a511aaeaddbbe09cf70f881" +
"ec568d574a2ffd4dabe5ee9820adaa47" +
"8e56fd8f4ba5d09ffa1c6d927c40f4c3" +
"37304049e8a952fbcbf45c6fa77a41a4" as char[])
// This test vector requires 2GB heap space and approximately 10 seconds on a consumer machine
String password = "pleaseletmein"
String salt = "SodiumChloride"
int n = 1048576
int r = 8
int p = 1
int dkLen = 64 * 8
// Act
logger.info("Running with '${password}', '${salt}', ${n}, ${r}, ${p}, ${dkLen}")
long memoryInBytes = Scrypt.calculateExpectedMemory(n, r, p)
logger.info("Expected memory usage: (128 * r * N + 128 * r * p) ${memoryInBytes} bytes")
logger.info(" Expected ${Hex.encodeHexString(HASH)}")
byte[] calculatedHash = Scrypt.deriveScryptKey(password.bytes, salt.bytes, n, r, p, dkLen)
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert
assertArrayEquals(HASH, calculatedHash)
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true")
@Test
void testShouldCauseOutOfMemoryError() {
SecureRandom secureRandom = new SecureRandom()
// int i = 29
(10..31).each { int i ->
int length = 2**i
byte[] bytes = new byte[length]
secureRandom.nextBytes(bytes)
logger.info("Successfully ran with byte[] of length ${length}")
logger.info("${Hex.encodeHexString(bytes[0..<16] as byte[])}...")
}
}
@Test
void testDeriveScryptKeyShouldSupportExternalCompatibility() {
// Arrange
// These values can be generated by running `$ ./openssl_scrypt.rb` in the terminal
final String EXPECTED_KEY_HEX = "a8efbc0a709d3f89b6bb35b05fc8edf5"
String password = "thisIsABadPassword"
String saltHex = "f5b8056ea6e66edb8d013ac432aba24a"
int n = 1024
int r = 8
int p = 36
int dkLen = 16 * 8
// Act
logger.info("Running with '${password}', ${saltHex}, ${n}, ${r}, ${p}, ${dkLen}")
long memoryInBytes = Scrypt.calculateExpectedMemory(n, r, p)
logger.info("Expected memory usage: (128 * r * N + 128 * r * p) ${memoryInBytes} bytes")
logger.info(" Expected ${EXPECTED_KEY_HEX}")
byte[] calculatedHash = Scrypt.deriveScryptKey(password.bytes, Hex.decodeHex(saltHex as char[]), n, r, p, dkLen)
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert
assertArrayEquals(Hex.decodeHex(EXPECTED_KEY_HEX as char[]), calculatedHash)
}
@Test
void testScryptShouldBeInternallyConsistent() throws Exception {
// Arrange
def allHashes = []
final int RUNS = 10
logger.info("Running with '${PASSWORD}', '${SALT_HEX}', $N, $R, $P")
// Act
RUNS.times {
String hash = Scrypt.scrypt(PASSWORD, SALT_BYTES, N, R, P, DK_LEN)
logger.info("Hash: ${hash}")
allHashes << hash
}
// Assert
assertEquals(RUNS, allHashes.size())
allHashes.forEach(hash -> assertEquals(allHashes.first(), hash))
}
@Test
void testScryptShouldGenerateValidSaltIfMissing() {
// Arrange
// The generated salt should be byte[16], encoded as 22 Base64 chars
final EXPECTED_SALT_PATTERN = /\$.+\$[0-9a-zA-Z\/\+]{22}\$.+/
// Act
String calculatedHash = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN)
logger.info("Generated ${calculatedHash}")
// Assert
assertTrue((calculatedHash =~ EXPECTED_SALT_PATTERN).matches())
}
@Test
void testScryptShouldNotAcceptInvalidN() throws Exception {
// Arrange
final int MAX_N = Integer.MAX_VALUE / 128 / R - 1
// N must be a power of 2 > 1 and < Integer.MAX_VALUE / 128 / r
final def INVALID_NS = [-2, 0, 1, 3, 4096 - 1, MAX_N + 1]
// Act
INVALID_NS.each { int invalidN ->
logger.info("Using N: ${invalidN}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, invalidN, R, P, DK_LEN))
// Assert
assertTrue((iae.getMessage() =~ "N must be a power of 2 greater than 1|Parameter N is too large").matches())
}
}
@Test
void testScryptShouldAcceptValidR() throws Exception {
// Arrange
// Use a large p value to allow r to exceed MAX_R without normal N exceeding MAX_N
int largeP = 2**10
final int MAX_R = Math.ceil(Integer.MAX_VALUE / 128 / largeP) - 1
// r must be in (0..Integer.MAX_VALUE / 128 / p)
final def INVALID_RS = [0, MAX_R + 1]
// Act
INVALID_RS.each { int invalidR ->
logger.info("Using r: ${invalidR}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, invalidR, largeP, DK_LEN))
// Assert
assertTrue((iae.getMessage() =~ "Parameter r must be 1 or greater|Parameter r is too large").matches())
}
}
@Test
void testScryptShouldNotAcceptInvalidP() throws Exception {
// Arrange
final int MAX_P = Math.ceil(Integer.MAX_VALUE / 128) - 1
// p must be in (0..Integer.MAX_VALUE / 128)
final def INVALID_PS = [0, MAX_P + 1]
// Act
INVALID_PS.each { int invalidP ->
logger.info("Using p: ${invalidP}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, R, invalidP, DK_LEN))
// Assert
assertTrue((iae.getMessage() =~ "Parameter p must be 1 or greater|Parameter p is too large").matches())
}
}
@Test
void testCheckShouldValidateCorrectPassword() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
final String EXPECTED_HASH = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN)
logger.info("Password: ${PASSWORD} -> Hash: ${EXPECTED_HASH}")
// Act
boolean matches = Scrypt.check(PASSWORD, EXPECTED_HASH)
logger.info("Check matches: ${matches}")
// Assert
assertTrue(matches)
}
@Test
void testCheckShouldNotValidateIncorrectPassword() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
final String EXPECTED_HASH = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN)
logger.info("Password: ${PASSWORD} -> Hash: ${EXPECTED_HASH}")
// Act
boolean matches = Scrypt.check(PASSWORD.reverse(), EXPECTED_HASH)
logger.info("Check matches: ${matches}")
// Assert
assertFalse(matches)
}
@Test
void testCheckShouldNotAcceptInvalidPassword() throws Exception {
// Arrange
final String HASH = '$s0$a0801$abcdefghijklmnopqrstuv$abcdefghijklmnopqrstuv'
// Even though the spec allows for empty passwords, the JCE does not, so extend enforcement of that to the user boundary
final def INVALID_PASSWORDS = ['', null]
// Act
INVALID_PASSWORDS.each { String invalidPassword ->
logger.info("Using password: ${invalidPassword}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.check(invalidPassword, HASH))
logger.expected(iae.getMessage())
// Assert
assertTrue(iae.getMessage().contains("Password cannot be empty"))
}
}
@Test
void testCheckShouldNotAcceptInvalidHash() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword"
// Even though the spec allows for empty salts, the JCE does not, so extend enforcement of that to the user boundary
final def INVALID_HASHES = ['', null, '$s0$a0801$', '$s0$a0801$abcdefghijklmnopqrstuv$']
// Act
INVALID_HASHES.each { String invalidHash ->
logger.info("Using hash: ${invalidHash}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.check(PASSWORD, invalidHash))
logger.expected(iae.getMessage())
// Assert
assertTrue((iae.getMessage() =~ "Hash cannot be empty|Hash is not properly formatted").matches())
}
}
@Test
void testVerifyHashFormatShouldDetectValidHash() throws Exception {
// Arrange
final def VALID_HASHES = [
"\$s0\$40801\$AAAAAAAAAAAAAAAAAAAAAA\$gLSh7ChbHdOIMvZ74XGjV6qF65d9qvQ8n75FeGnM8YM",
"\$s0\$40801\$ABCDEFGHIJKLMNOPQRSTUQ\$hxU5g0eH6sRkBqcsiApI8jxvKRT+2QMCenV0GToiMQ8",
"\$s0\$40801\$eO+UUcKYL2gnpD51QCc+gnywQ7Eg9tZeLMlf0XXr2zc\$99aTTB39TJo69aZCONQmRdyWOgYsDi+1MI+8D0EgMNM",
"\$s0\$40801\$AAAAAAAAAAAAAAAAAAAAAA\$Gk7K9YmlsWbd8FS7e4RKVWnkg9vlsqYnlD593pJ71gg",
"\$s0\$40801\$ABCDEFGHIJKLMNOPQRSTUQ\$Ri78VZbrp2cCVmGh2a9Nbfdov8LPnFb49MYyzPCaXmE",
"\$s0\$40801\$eO+UUcKYL2gnpD51QCc+gnywQ7Eg9tZeLMlf0XXr2zc\$rZIrP2qdIY7LN4CZAMgbCzl3YhXz6WhaNyXJXqFIjaI",
"\$s0\$40801\$AAAAAAAAAAAAAAAAAAAAAA\$GxH68bGykmPDZ6gaPIGOONOT2omlZ7cd0xlcZ9UsY/0",
"\$s0\$40801\$ABCDEFGHIJKLMNOPQRSTUQ\$KLGZjWlo59sbCbtmTg5b4k0Nu+biWZRRzhPhN7K5kkI",
"\$s0\$40801\$eO+UUcKYL2gnpD51QCc+gnywQ7Eg9tZeLMlf0XXr2zc\$6Ql6Efd2ac44ERoV31CL3Q0J3LffNZKN4elyMHux99Y",
// Uncommon but technically valid
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$A",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$" +
"ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
]
// Act
VALID_HASHES.each { String validHash ->
logger.info("Using hash: ${validHash}")
boolean isValidHash = Scrypt.verifyHashFormat(validHash)
logger.info("Hash is valid: ${isValidHash}")
// Assert
assertTrue(isValidHash)
}
}
@Test
void testVerifyHashFormatShouldDetectInvalidHash() throws Exception {
// Arrange
// Even though the spec allows for empty salts, the JCE does not, so extend enforcement of that to the user boundary
final def INVALID_HASHES = ['', null, '$s0$a0801$', '$s0$a0801$abcdefghijklmnopqrstuv$']
// Act
INVALID_HASHES.each { String invalidHash ->
logger.info("Using hash: ${invalidHash}")
boolean isValidHash = Scrypt.verifyHashFormat(invalidHash)
logger.info("Hash is valid: ${isValidHash}")
// Assert
assertFalse(isValidHash)
}
}
}

View File

@ -14,311 +14,286 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package org.apache.nifi.security.util.crypto package org.apache.nifi.security.util.crypto;
import org.apache.commons.codec.binary.Hex import org.apache.commons.codec.DecoderException;
import org.apache.nifi.security.util.EncryptionMethod import org.apache.commons.codec.binary.Hex;
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.apache.nifi.security.util.EncryptionMethod;
import org.junit.jupiter.api.BeforeAll import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.Test import org.junit.jupiter.api.BeforeAll;
import org.slf4j.Logger import org.junit.jupiter.api.Test;
import org.slf4j.LoggerFactory
import javax.crypto.Cipher import javax.crypto.Cipher;
import javax.crypto.SecretKey import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec;
import java.security.SecureRandom import java.security.SecureRandom;
import java.security.Security import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue
class AESKeyedCipherProviderGroovyTest { public class AESKeyedCipherProviderTest {
private static final Logger logger = LoggerFactory.getLogger(AESKeyedCipherProviderGroovyTest.class) private static final String KEY_HEX = "0123456789ABCDEFFEDCBA9876543210";
private static final String KEY_HEX = "0123456789ABCDEFFEDCBA9876543210" private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess";
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess" private static final List<EncryptionMethod> keyedEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(EncryptionMethod::isKeyedCipher)
.collect(Collectors.toList());
private static final List<EncryptionMethod> keyedEncryptionMethods = EncryptionMethod.values().findAll { it.keyedCipher } private static SecretKey key;
private static final SecretKey key = new SecretKeySpec(Hex.decodeHex(KEY_HEX as char[]), "AES")
@BeforeAll @BeforeAll
static void setUpOnce() throws Exception { static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider()) Security.addProvider(new BouncyCastleProvider());
logger.metaClass.methodMissing = { String name, args -> try {
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") key = new SecretKeySpec(Hex.decodeHex(KEY_HEX.toCharArray()), "AES");
} catch (final DecoderException e) {
throw new RuntimeException(e);
} }
} }
private static boolean isUnlimitedStrengthCryptoAvailable() {
Cipher.getMaxAllowedKeyLength("AES") > 128
}
@Test @Test
void testGetCipherShouldBeInternallyConsistent() throws Exception { void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
// Act // Act
for (EncryptionMethod em : keyedEncryptionMethods) { for (EncryptionMethod em : keyedEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, key, true) Cipher cipher = cipherProvider.getCipher(em, key, true);
byte[] iv = cipher.getIV() byte[] iv = cipher.getIV();
logger.info("IV: ${Hex.encodeHexString(iv)}")
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, key, iv, false) cipher = cipherProvider.getCipher(em, key, iv, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes) byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8");
logger.info("Recovered: ${recovered}")
// Assert // Assert
assertEquals(PLAINTEXT, recovered) assertEquals(PLAINTEXT, recovered);
} }
} }
@Test @Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception { void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
// Act // Act
keyedEncryptionMethods.each { EncryptionMethod em -> for (final EncryptionMethod em : keyedEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}") byte[] iv = cipherProvider.generateIV();
byte[] iv = cipherProvider.generateIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, key, iv, true) Cipher cipher = cipherProvider.getCipher(em, key, iv, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, key, iv, false) cipher = cipherProvider.getCipher(em, key, iv, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes) byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8");
logger.info("Recovered: ${recovered}")
// Assert // Assert
assertEquals(PLAINTEXT, recovered) assertEquals(PLAINTEXT, recovered);
} }
} }
@Test @Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception { void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange // Arrange
assumeTrue(isUnlimitedStrengthCryptoAvailable(), "Test is being skipped due to this JVM lacking JCE Unlimited Strength Jurisdiction Policy file.") KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
final List<Integer> longKeyLengths = Arrays.asList(192, 256);
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() SecureRandom secureRandom = new SecureRandom();
final List<Integer> LONG_KEY_LENGTHS = [192, 256]
SecureRandom secureRandom = new SecureRandom()
// Act // Act
keyedEncryptionMethods.each { EncryptionMethod em -> for (final EncryptionMethod em : keyedEncryptionMethods) {
// Re-use the same IV for the different length keys to ensure the encryption is different // Re-use the same IV for the different length keys to ensure the encryption is different
byte[] iv = cipherProvider.generateIV() byte[] iv = cipherProvider.generateIV();
logger.info("IV: ${Hex.encodeHexString(iv)}")
LONG_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${em.getAlgorithm()} with key length ${keyLength}")
for (final int keyLength: longKeyLengths) {
// Generate a key // Generate a key
byte[] keyBytes = new byte[keyLength / 8] byte[] keyBytes = new byte[keyLength / 8];
secureRandom.nextBytes(keyBytes) secureRandom.nextBytes(keyBytes);
SecretKey localKey = new SecretKeySpec(keyBytes, "AES") SecretKey localKey = new SecretKeySpec(keyBytes, "AES");
logger.info("Key: ${Hex.encodeHexString(keyBytes)} ${keyBytes.length}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, localKey, iv, true) Cipher cipher = cipherProvider.getCipher(em, localKey, iv, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
cipher = cipherProvider.getCipher(em, localKey, iv, false) cipher = cipherProvider.getCipher(em, localKey, iv, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes) byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8");
logger.info("Recovered: ${recovered}")
// Assert // Assert
assertEquals(PLAINTEXT, recovered) assertEquals(PLAINTEXT, recovered);
} }
} }
} }
@Test @Test
void testShouldRejectEmptyKey() throws Exception { void testShouldRejectEmptyKey() {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act // Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, null, true)) () -> cipherProvider.getCipher(encryptionMethod, null, true));
// Assert // Assert
assertTrue(iae.message.contains("The key must be specified")) assertTrue(iae.getMessage().contains("The key must be specified"));
} }
@Test @Test
void testShouldRejectIncorrectLengthKey() throws Exception { void testShouldRejectIncorrectLengthKey() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
SecretKey localKey = new SecretKeySpec(Hex.decodeHex("0123456789ABCDEF" as char[]), "AES") SecretKey localKey = new SecretKeySpec(Hex.decodeHex("0123456789ABCDEF".toCharArray()), "AES");
assertFalse([128, 192, 256].contains(localKey.encoded.length)) assertFalse(Arrays.asList(128, 192, 256).contains(localKey.getEncoded().length));
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act // Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, localKey, true)) () -> cipherProvider.getCipher(encryptionMethod, localKey, true));
// Assert // Assert
assertTrue(iae.message.contains("The key must be of length [128, 192, 256]")) assertTrue(iae.getMessage().contains("The key must be of length [128, 192, 256]"));
} }
@Test @Test
void testShouldRejectEmptyEncryptionMethod() throws Exception { void testShouldRejectEmptyEncryptionMethod() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
// Act // Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(null, key, true)) () -> cipherProvider.getCipher(null, key, true));
// Assert // Assert
assertTrue(iae.message.contains("The encryption method must be specified")) assertTrue(iae.getMessage().contains("The encryption method must be specified"));
} }
@Test @Test
void testShouldRejectUnsupportedEncryptionMethod() throws Exception { void testShouldRejectUnsupportedEncryptionMethod() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
final EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES final EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES;
// Act // Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, key, true)) () -> cipherProvider.getCipher(encryptionMethod, key, true));
// Assert // Assert
assertTrue(iae.message.contains("requires a PBECipherProvider")) assertTrue(iae.getMessage().contains("requires a PBECipherProvider"));
} }
@Test @Test
void testGetCipherShouldSupportExternalCompatibility() throws Exception { void testGetCipherShouldSupportExternalCompatibility() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
final String plaintext = "This is a plaintext message." final String plaintext = "This is a plaintext message.";
// These values can be generated by running `$ ./openssl_aes.rb` in the terminal // These values can be generated by running `$ ./openssl_aes.rb` in the terminal
final byte[] IV = Hex.decodeHex("e0bc8cc7fbc0bdfdc184dc22ce2fcb5b" as char[]) final byte[] IV = Hex.decodeHex("e0bc8cc7fbc0bdfdc184dc22ce2fcb5b".toCharArray());
final byte[] LOCAL_KEY = Hex.decodeHex("c72943d27c3e5a276169c5998a779117" as char[]) final byte[] LOCAL_KEY = Hex.decodeHex("c72943d27c3e5a276169c5998a779117".toCharArray());
final String CIPHER_TEXT = "a2725ea55c7dd717664d044cab0f0b5f763653e322c27df21954f5be394efb1b" final String CIPHER_TEXT = "a2725ea55c7dd717664d044cab0f0b5f763653e322c27df21954f5be394efb1b";
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT as char[]) byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT.toCharArray());
SecretKey localKey = new SecretKeySpec(LOCAL_KEY, "AES") SecretKey localKey = new SecretKeySpec(LOCAL_KEY, "AES");
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
// Act // Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, localKey, IV, false) Cipher cipher = cipherProvider.getCipher(encryptionMethod, localKey, IV, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes) byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8");
logger.info("Recovered: ${recovered}")
// Assert // Assert
assertEquals(plaintext, recovered) assertEquals(plaintext, recovered);
} }
@Test @Test
void testGetCipherForDecryptShouldRequireIV() throws Exception { void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
// Act // Act
keyedEncryptionMethods.each { EncryptionMethod em -> for (final EncryptionMethod em : keyedEncryptionMethods) {
logger.info("Using algorithm: ${em.getAlgorithm()}") byte[] iv = cipherProvider.generateIV();
byte[] iv = cipherProvider.generateIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, key, iv, true) Cipher cipher = cipherProvider.getCipher(em, key, iv, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, key, false)) () -> cipherProvider.getCipher(em, key, false));
// Assert // Assert
assertTrue(iae.message.contains("Cannot decrypt without a valid IV")) assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
} }
} }
@Test @Test
void testGetCipherShouldRejectInvalidIVLengths() throws Exception { void testGetCipherShouldRejectInvalidIVLengths() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
final def INVALID_IVS = (0..15).collect { int length -> new byte[length] } final int MAX_LENGTH = 15;
final List<byte[]> INVALID_IVS = new ArrayList<>();
for (int length = 0; length <= MAX_LENGTH; length++) {
INVALID_IVS.add(new byte[length]);
}
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act // Act
INVALID_IVS.each { byte[] badIV -> for (final byte[] badIV : INVALID_IVS) {
logger.info("IV: ${Hex.encodeHexString(badIV)} ${badIV.length}")
// Encrypt should print a warning about the bad IV but overwrite it // Encrypt should print a warning about the bad IV but overwrite it
Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true);
// Decrypt should fail // Decrypt should fail
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, key, badIV, false)) () -> cipherProvider.getCipher(encryptionMethod, key, badIV, false));
logger.warn(iae.getMessage())
// Assert // Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV")) assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
} }
} }
@Test @Test
void testGetCipherShouldRejectEmptyIV() throws Exception { void testGetCipherShouldRejectEmptyIV() throws Exception {
// Arrange // Arrange
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider();
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
byte[] badIV = [0x00 as byte] * 16 as byte[] byte[] badIV = new byte[16];
Arrays.fill(badIV, (byte) '\0');
// Act
logger.info("IV: ${Hex.encodeHexString(badIV)} ${badIV.length}")
// Encrypt should print a warning about the bad IV but overwrite it // Encrypt should print a warning about the bad IV but overwrite it
Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true);
logger.info("IV after encrypt: ${Hex.encodeHexString(cipher.getIV())}")
// Decrypt should fail // Decrypt should fail
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, key, badIV, false)) () -> cipherProvider.getCipher(encryptionMethod, key, badIV, false));
logger.warn(iae.getMessage())
// Assert // Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV")) assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
} }
} }

View File

@ -0,0 +1,407 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.util.StringUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class Argon2CipherProviderTest {
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess";
private static final String BAD_PASSWORD = "thisIsABadPassword";
private static final String SHORT_PASSWORD = "shortPassword";
private static final int DEFAULT_KEY_LENGTH = 128;
private final String SALT_HEX = "0123456789ABCDEFFEDCBA9876543210";
private final List<Integer> FULL_SALT_LENGTH_RANGE = Arrays.asList(49, 50, 51, 52, 53);
private static List<Integer> VALID_KEY_LENGTHS;
private static List<EncryptionMethod> strongKDFEncryptionMethods;
private RandomIVPBECipherProvider cipherProvider;
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
strongKDFEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(EncryptionMethod::isCompatibleWithStrongKDFs)
.collect(Collectors.toList());
VALID_KEY_LENGTHS = Arrays.asList(128, 192, 256);
}
@BeforeEach
void setUp() {
// Very fast parameters to test for correctness rather than production values
cipherProvider = new Argon2CipherProvider(1024, 1, 3);
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testArgon2ShouldSupportExternalCompatibility() throws Exception {
// Arrange
// Default values are hashLength = 32, memory = 1024, parallelism = 1, iterations = 3, but the provided salt will contain the parameters used
cipherProvider = new Argon2CipherProvider();
final String PLAINTEXT = "This is a plaintext message.";
final int hashLength = 256;
// These values can be generated by running `$ ./openssl_argon2.rb` in the terminal
final byte[] SALT = Hex.decodeHex("68d29a1d8021f45954333767358a2492".toCharArray());
final byte[] IV = Hex.decodeHex("808590f35f9fba14dbda9c2bb2b76a79".toCharArray());
final String CIPHER_TEXT = "d672412857916880c79d573aa4f9d4971b85f07438d6f62f38a0e31314caa2e5";
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT.toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Sanity check
String rubyKeyHex = "8caf581795886d38f0c605e3d674f4961c658ee3625a8e8868be36c902d234ef";
Cipher rubyCipher = Cipher.getInstance(encryptionMethod.getAlgorithm(), "BC");
SecretKeySpec rubyKey = new SecretKeySpec(Hex.decodeHex(rubyKeyHex.toCharArray()), "AES");
IvParameterSpec ivSpec = new IvParameterSpec(IV);
rubyCipher.init(Cipher.ENCRYPT_MODE, rubyKey, ivSpec);
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.getBytes());
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec);
assertArrayEquals(PLAINTEXT.getBytes(), rubyCipher.doFinal(rubyCipherBytes));
assertArrayEquals(PLAINTEXT.getBytes(), rubyCipher.doFinal(cipherBytes));
// $argon2id$v=19$m=memory,t=iterations,p=parallelism$saltB64$hashB64
final String FULL_HASH = "$argon2id$v=19$m=256,t=3,p=1$aNKaHYAh9FlUMzdnNYokkg$jK9YF5WIbTjwxgXj1nT0lhxljuNiWo6IaL42yQLSNO8";
final String FULL_SALT = FULL_HASH.substring(0, FULL_HASH.lastIndexOf("$"));
final String[] hashComponents = FULL_HASH.split("\\$");
final String saltB64 = hashComponents[4];
byte[] salt = Base64.decodeBase64(saltB64);
assertArrayEquals(SALT, salt);
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, FULL_SALT.getBytes(), IV, hashLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
@Test
void testGetCipherShouldRejectInvalidIV() throws Exception {
// Arrange
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final int MAX_LENGTH = 15;
final List<byte[]> INVALID_IVS = new ArrayList<>();
for (int length = 0; length <= MAX_LENGTH; length++) {
INVALID_IVS.add(new byte[length]);
}
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final byte[] badIV: INVALID_IVS) {
// Encrypt should print a warning about the bad IV but overwrite it
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true);
// Decrypt should fail
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false));
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("01", 16).toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final int LONG_KEY_LENGTH = 256;
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, LONG_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, LONG_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldNotAcceptInvalidSalts() throws Exception {
// Arrange
final List<String> INVALID_SALTS = Arrays.asList("argon2", "$3a$11$", "x", "$2a$10$");
final String LENGTH_MESSAGE = "The raw salt must be greater than or equal to 8 bytes";
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String salt : INVALID_SALTS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, salt.getBytes(), DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains(LENGTH_MESSAGE));
}
}
@Test
void testGetCipherShouldHandleUnformattedSalts() throws Exception {
// Arrange
final List<String> RECOVERABLE_SALTS = Arrays.asList("$ab$00$acbdefghijklmnopqrstuv", "$4$1$1$0123456789abcdef", "$400$1$1$abcdefghijklmnopqrstuv");
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String salt : RECOVERABLE_SALTS) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, salt.getBytes(), DEFAULT_KEY_LENGTH, true);
// Assert
assertNotNull(cipher);
}
}
@Test
void testGetCipherShouldRejectEmptySalt() throws Exception {
// Arrange
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()"));
}
@Test
void testGenerateSaltShouldProvideValidSalt() throws Exception {
// Arrange
Argon2CipherProvider cipherProvider = new Argon2CipherProvider();
// Act
byte[] saltBytes = cipherProvider.generateSalt();
String fullSalt = new String(saltBytes, StandardCharsets.UTF_8);
final Matcher matcher = Pattern.compile("\\$([\\w\\+\\/]+)\\$?$").matcher(fullSalt);
matcher.find();
final String rawSaltB64 = matcher.group(1);
byte[] rawSaltBytes = Base64.decodeBase64(rawSaltB64);
// Assert
boolean isValidFormattedSalt = cipherProvider.isArgon2FormattedSalt(fullSalt);
assertTrue(isValidFormattedSalt);
boolean fullSaltIsValidLength = FULL_SALT_LENGTH_RANGE.contains(saltBytes.length);
assertTrue(fullSaltIsValidLength);
byte[] notExpected = new byte[16];
Arrays.fill(notExpected, (byte) '\0');
assertFalse(Arrays.equals(notExpected, rawSaltBytes));
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("00", 16).toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, false));
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("01", 16).toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : VALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("00", 16).toCharArray());
final List<Integer> INVALID_KEY_LENGTHS = Arrays.asList(-1, 40, 64, 112, 512);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : INVALID_KEY_LENGTHS) {
// Initialize a cipher for
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true));
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"));
}
}
@Test
void testArgon2ShouldNotAcceptInvalidPassword() {
// Arrange
String badPassword = "";
byte[] salt = new byte[16];
Arrays.fill(salt, (byte) 0x01);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"));
}
@Test
void testShouldRejectInvalidSalt() throws Exception {
// Arrange
Argon2CipherProvider cipherProvider = new Argon2CipherProvider();
final String FULL_HASH = "$argon2id$v=19$m=1024,t=4,p=1$hiKyaQbZyQBmCmD1zGcyMw$rc+ec+/hQeBcwzjH+OEmUtaTUqhZYKN4ZKJtWzFZYjQ";
// Act
boolean isValid = cipherProvider.isArgon2FormattedSalt(FULL_HASH);
// Assert
assertFalse(isValid);
}
@Test
void testShouldExtractSalt() throws Exception {
// Arrange
Argon2CipherProvider cipherProvider = new Argon2CipherProvider();
final byte[] EXPECTED_RAW_SALT = Hex.decodeHex("8622b26906d9c900660a60f5cc673233".toCharArray());
final String FORMATTED_SALT = "$argon2id$v=19$m=1024,t=4,p=1$hiKyaQbZyQBmCmD1zGcyMw";
byte[] rawSalt;
// Act
rawSalt = cipherProvider.extractRawSaltFromArgon2Salt(FORMATTED_SALT);
// Assert
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt);
}
}

View File

@ -0,0 +1,414 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class Argon2SecureHasherTest {
@Test
void testShouldBeDeterministicWithStaticSalt() {
// Arrange
int hashLength = 32;
int memory = 8;
int parallelism = 4;
int iterations = 4;
int testIterations = 10;
byte[] inputBytes = "This is a sensitive value".getBytes();
final String EXPECTED_HASH_HEX = "a73a471f51b2900901a00b81e770b9c1dfc595602bb7aec64cd27754a4174919";
Argon2SecureHasher a2sh = new Argon2SecureHasher(hashLength, memory, parallelism, iterations);
final List<String> results = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
byte[] hash = a2sh.hashRaw(inputBytes);
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
}
// Assert
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
}
@Test
void testShouldBeDifferentWithRandomSalt() {
// Arrange
int hashLength = 32;
int memory = 8;
int parallelism = 4;
int iterations = 4;
int testIterations = 10;
byte[] inputBytes = "This is a sensitive value".getBytes();
final String EXPECTED_HASH_HEX = "a73a471f51b2900901a00b81e770b9c1dfc595602bb7aec64cd27754a4174919";
Argon2SecureHasher a2sh = new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 16);
final List<String> results = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
byte[] hash = a2sh.hashRaw(inputBytes);
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
}
// Assert
assertTrue(results.stream().distinct().collect(Collectors.toList()).size() == results.size());
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result));
}
@Test
void testShouldHandleArbitrarySalt() {
// Arrange
int hashLength = 32;
int memory = 8;
int parallelism = 4;
int iterations = 4;
final String input = "This is a sensitive value";
byte[] inputBytes = input.getBytes();
final String EXPECTED_HASH_HEX = "a73a471f51b2900901a00b81e770b9c1dfc595602bb7aec64cd27754a4174919";
final String EXPECTED_HASH_BASE64 = "pzpHH1GykAkBoAuB53C5wd/FlWArt67GTNJ3VKQXSRk";
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX);
// Static salt instance
Argon2SecureHasher staticSaltHasher = new Argon2SecureHasher(hashLength, memory, parallelism, iterations);
Argon2SecureHasher arbitrarySaltHasher = new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 16);
final byte[] STATIC_SALT = "NiFi Static Salt".getBytes(StandardCharsets.UTF_8);
final String DIFFERENT_STATIC_SALT = "Diff Static Salt";
// Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes);
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT);
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8));
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes);
String staticSaltHashHex = staticSaltHasher.hashHex(input);
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT);
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input);
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input);
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT);
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input);
// Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash);
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash);
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash));
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash));
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex);
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex);
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64);
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64);
}
@Test
void testShouldValidateArbitrarySalt() {
// Arrange
int hashLength = 32;
int memory = 8;
int parallelism = 4;
int iterations = 4;
final String input = "This is a sensitive value";
byte[] inputBytes = input.getBytes();
// Static salt instance
Argon2SecureHasher secureHasher = new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 16);
final byte[] STATIC_SALT = "bad_sal".getBytes();
// Act
assertThrows(IllegalArgumentException.class, () ->
new Argon2SecureHasher(hashLength, memory, parallelism, iterations, 7)
);
assertThrows(RuntimeException.class, () -> secureHasher.hashRaw(inputBytes, STATIC_SALT));
assertThrows(RuntimeException.class, () -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
assertThrows(RuntimeException.class, () -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
}
@Test
void testShouldFormatHex() {
// Arrange
String input = "This is a sensitive value";
final String EXPECTED_HASH_HEX = "0c2920c52f28e0a2c77d006ec6138c8dc59580881468b85541cf886abdebcf18";
Argon2SecureHasher a2sh = new Argon2SecureHasher(32, 4096, 1, 3);
// Act
String hashHex = a2sh.hashHex(input);
// Assert
assertEquals(EXPECTED_HASH_HEX, hashHex);
}
@Test
void testShouldFormatBase64() {
// Arrange
String input = "This is a sensitive value";
final String EXPECTED_HASH_B64 = "DCkgxS8o4KLHfQBuxhOMjcWVgIgUaLhVQc+Iar3rzxg";
Argon2SecureHasher a2sh = new Argon2SecureHasher(32, 4096, 1, 3);
// Act
String hashB64 = a2sh.hashBase64(input);
// Assert
assertEquals(EXPECTED_HASH_B64, hashB64);
}
@Test
void testShouldHandleNullInput() {
// Arrange
List<String> inputs = Arrays.asList(null, "");
final String EXPECTED_HASH_HEX = "8e5625a66b94ed9d31c1496d7f9ff49249cf05d6753b50ba0e2bf2a1108973dd";
final String EXPECTED_HASH_B64 = "jlYlpmuU7Z0xwUltf5/0kknPBdZ1O1C6DivyoRCJc90";
Argon2SecureHasher a2sh = new Argon2SecureHasher(32, 4096, 1, 3);
final List<String> hexResults = new ArrayList<>();
final List<String> b64Results = new ArrayList<>();
// Act
for (final String input : inputs) {
String hashHex = a2sh.hashHex(input);
hexResults.add(hashHex);
String hashB64 = a2sh.hashBase64(input);
b64Results.add(hashB64);
}
// Assert
hexResults.forEach(hexResult -> assertEquals(EXPECTED_HASH_HEX, hexResult));
b64Results.forEach(b64Result -> assertEquals(EXPECTED_HASH_B64, b64Result));
}
/**
* This test can have the minimum time threshold updated to determine if the performance
* is still sufficient compared to the existing threat model.
*/
@EnabledIfSystemProperty(named = "nifi.test.performance", matches = "true")
@Test
void testDefaultCostParamsShouldBeSufficient() {
// Arrange
int testIterations = 100; //_000
byte[] inputBytes = "This is a sensitive value".getBytes();
Argon2SecureHasher a2sh = new Argon2SecureHasher(16, (int) Math.pow(2, 16), 8, 5);
final List<String> results = new ArrayList<>();
final List<Long> resultDurations = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
long startNanos = System.nanoTime();
byte[] hash = a2sh.hashRaw(inputBytes);
long endNanos = System.nanoTime();
long durationNanos = endNanos - startNanos;
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
resultDurations.add(durationNanos);
}
// Assert
final long MIN_DURATION_NANOS = 500_000_000; // 500 ms
assertTrue(Collections.min(resultDurations) > MIN_DURATION_NANOS);
assertTrue(resultDurations.stream().mapToLong(Long::longValue).sum() / testIterations > MIN_DURATION_NANOS);
}
@Test
void testShouldVerifyHashLengthBoundary() throws Exception {
// Arrange
final int hashLength = 128;
// Act
boolean valid = Argon2SecureHasher.isHashLengthValid(hashLength);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailHashLengthBoundary() throws Exception {
// Arrange
final List<Integer> hashLengths = Arrays.asList(-8, 0, 1, 2);
// Act & Assert
for (final int hashLength: hashLengths) {
assertFalse(Argon2SecureHasher.isHashLengthValid(hashLength));
}
}
@Test
void testShouldVerifyMemorySizeBoundary() throws Exception {
// Arrange
final int memory = 2048;
// Act
boolean valid = Argon2SecureHasher.isMemorySizeValid(memory);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailMemorySizeBoundary() throws Exception {
// Arrange
final List<Integer> memorySizes = Arrays.asList(-12, 0, 1, 6);
// Act & Assert
for (final int memory : memorySizes) {
assertFalse(Argon2SecureHasher.isMemorySizeValid(memory));
}
}
@Test
void testShouldVerifyParallelismBoundary() throws Exception {
// Arrange
final int parallelism = 4;
// Act
boolean valid = Argon2SecureHasher.isParallelismValid(parallelism);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailParallelismBoundary() throws Exception {
// Arrange
final List<Integer> parallelisms = Arrays.asList(-8, 0, 16777220, 16778000);
// Act & Assert
for (final int parallelism : parallelisms) {
assertFalse(Argon2SecureHasher.isParallelismValid(parallelism));
}
}
@Test
void testShouldVerifyIterationsBoundary() throws Exception {
// Arrange
final int iterations = 4;
// Act
boolean valid = Argon2SecureHasher.isIterationsValid(iterations);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailIterationsBoundary() throws Exception {
// Arrange
final List<Integer> iterationCounts = Arrays.asList(-50, -1, 0);
// Act & Assert
for (final int iterations: iterationCounts) {
assertFalse(Argon2SecureHasher.isIterationsValid(iterations));
}
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
final List<Integer> saltLengths = Arrays.asList(0, 64);
// Act and Assert
Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher();
saltLengths.forEach(saltLength ->
assertTrue(argon2SecureHasher.isSaltLengthValid(saltLength))
);
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
final List<Integer> saltLengths = Arrays.asList(-16, 4);
// Act and Assert
Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher();
saltLengths.forEach(saltLength -> assertFalse(argon2SecureHasher.isSaltLengthValid(saltLength)));
}
@Test
void testShouldCreateHashOfDesiredLength() throws Exception {
// Arrange
final List<Integer> hashLengths = Arrays.asList(16, 32);
final String PASSWORD = "password";
final byte[] SALT = new byte[16];
Arrays.fill(SALT, (byte) '\0');
final byte[] EXPECTED_HASH = Hex.decode("411c9c87e7c91d8c8eacc418665bd2e1");
// Act
Map<Integer, byte[]> results = hashLengths
.stream()
.collect(
Collectors.toMap(
Function.identity(),
hashLength -> {
Argon2SecureHasher ash = new Argon2SecureHasher(hashLength, 8, 1, 3);
final byte[] hash = ash.hashRaw(PASSWORD.getBytes(), SALT);
return hash;
}
)
);
// Assert
assertFalse(Arrays.equals(Arrays.copyOf(results.get(16), 16), Arrays.copyOf(results.get(32), 16)));
// Demonstrates that internal hash truncation is not supported
// assert results.every { int k, byte[] v -> v[0..15] as byte[] == EXPECTED_HASH}
}
}

View File

@ -0,0 +1,474 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import at.favre.lib.crypto.bcrypt.BCrypt;
import at.favre.lib.crypto.bcrypt.Radix64Encoder;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.util.StringUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.Security;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class BcryptCipherProviderTest {
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess";
private static final String BAD_PASSWORD = "thisIsABadPassword";
private static final String SHORT_PASSWORD = "shortPassword";
private static List<EncryptionMethod> strongKDFEncryptionMethods;
private static final int DEFAULT_KEY_LENGTH = 128;
private static List<Integer> AES_KEY_LENGTHS;
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
strongKDFEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(EncryptionMethod::isCompatibleWithStrongKDFs)
.collect(Collectors.toList());
AES_KEY_LENGTHS = Arrays.asList(128, 192, 256);
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("01", 16).toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
final int LONG_KEY_LENGTH = 256;
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, LONG_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, LONG_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testHashPWShouldMatchTestVectors() {
// Arrange
final byte[] PASSWORD = "abcdefghijklmnopqrstuvwxyz".getBytes(StandardCharsets.UTF_8);
final byte[] SALT = new Radix64Encoder.Default().decode("fVH8e28OQRj9tqiDXs1e1u".getBytes(StandardCharsets.UTF_8));
final String EXPECTED_HASH = "$2a$10$fVH8e28OQRj9tqiDXs1e1uxpsjN0c7II7YPKXua2NAKYvM6iQk7dq";
final int WORK_FACTOR = 10;
// Act
String libraryCalculatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, SALT, PASSWORD), StandardCharsets.UTF_8);
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher(WORK_FACTOR);
String secureHasherCalculatedHash = new String(bcryptSecureHasher.hashRaw(PASSWORD, SALT), StandardCharsets.UTF_8);
// Assert
assertEquals(EXPECTED_HASH, secureHasherCalculatedHash);
assertEquals(EXPECTED_HASH, secureHasherCalculatedHash);
}
@Test
void testGetCipherShouldSupportExternalCompatibility() throws Exception {
// Arrange
final int WORK_FACTOR = 10;
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(WORK_FACTOR);
final String PLAINTEXT = "This is a plaintext message.";
// These values can be generated by running `$ ./openssl_bcrypt` in the terminal
// The Ruby bcrypt gem does not expose the custom Radix64 decoder, so maintain the R64 encoding from the output and decode here
final byte[] SALT = new Radix64Encoder.Default().decode("LBVzJoPgh.85YCvnos4BKO".getBytes());
final byte[] IV = Hex.decodeHex("bae8a9d935748a75ff0e0bbd95a4f024".toCharArray());
// $v2$w2$base64_salt_22__base64_hash_31
final String FULL_HASH = "$2a$10$LBVzJoPgh.85YCvnos4BKOyYM.LRni6UbU4v/CEPBkmFIiigADJZi";
final String CIPHER_TEXT = "d232b68e7aa38242d195c54b8f360d8b8d6b7580b190ffdeef99f5fe460bd6b0";
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT.toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Sanity check
Cipher rubyCipher = Cipher.getInstance(encryptionMethod.getAlgorithm(), "BC");
SecretKeySpec rubyKey = new SecretKeySpec(Hex.decodeHex("01ea96ccc48a1d045bd7f461721b94a8".toCharArray()), "AES");
IvParameterSpec ivSpec = new IvParameterSpec(IV);
rubyCipher.init(Cipher.ENCRYPT_MODE, rubyKey, ivSpec);
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.getBytes());
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec);
assertArrayEquals(PLAINTEXT.getBytes(), rubyCipher.doFinal(rubyCipherBytes));
assertArrayEquals(PLAINTEXT.getBytes(), rubyCipher.doFinal(cipherBytes));
// Sanity for hash generation
final String FULL_SALT = FULL_HASH.substring(0, 29);
String generatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, BcryptCipherProvider.extractRawSalt(FULL_SALT), BAD_PASSWORD.getBytes()));
assertEquals(FULL_HASH, generatedHash);
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, FULL_SALT.getBytes(), IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
private static byte[] customB64Decode(String input) {
return customB64Decode(input.getBytes());
}
private static byte[] customB64Decode(byte[] input) {
return new Radix64Encoder.Default().decode(input);
}
private static String customB64Encode(String input) {
return customB64Encode(input.getBytes());
}
private static String customB64Encode(byte[] input) {
return new String(new Radix64Encoder.Default().encode(input), StandardCharsets.UTF_8);
}
@Test
void testGetCipherShouldHandleFullSalt() throws Exception {
// Arrange
final int WORK_FACTOR = 10;
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(WORK_FACTOR);
final String PLAINTEXT = "This is a plaintext message.";
// These values can be generated by running `$ ./openssl_bcrypt.rb` in the terminal
final byte[] IV = Hex.decodeHex("bae8a9d935748a75ff0e0bbd95a4f024".toCharArray());
// $v2$w2$base64_salt_22__base64_hash_31
final String FULL_HASH = "$2a$10$LBVzJoPgh.85YCvnos4BKOyYM.LRni6UbU4v/CEPBkmFIiigADJZi";
final String FULL_SALT = FULL_HASH.substring(0, 29);
final String CIPHER_TEXT = "d232b68e7aa38242d195c54b8f360d8b8d6b7580b190ffdeef99f5fe460bd6b0";
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT.toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, FULL_SALT.getBytes(), IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
@Test
void testGetCipherShouldHandleUnformedSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final List<String> INVALID_SALTS = Arrays.asList("$ab$00$acbdefghijklmnopqrstuv", "bad_salt", "$3a$11$", "x", "$2a$10$");
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String salt : INVALID_SALTS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, salt.getBytes(), DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("The salt must be of the format $2a$10$gUVbkVzp79H8YaCOsCVZNu. To generate a salt, use BcryptCipherProvider#generateSalt"));
}
}
@Test
void testGetCipherShouldRejectEmptySalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Two different errors -- one explaining the no-salt method is not supported, and the other for an empty byte[] passed
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("format"));
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("00", 16).toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, false));
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("01", 16).toCharArray());
// Currently only AES ciphers are compatible with Bcrypt, so redundant to test all algorithms
final List<Integer> VALID_KEY_LENGTHS = AES_KEY_LENGTHS;
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : VALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("00", 16).toCharArray());
// Currently only AES ciphers are compatible with Bcrypt, so redundant to test all algorithms
final List<Integer> INVALID_KEY_LENGTHS = Arrays.asList(-1, 40, 64, 112, 512);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : INVALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true));
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"));
}
}
@Test
void testGenerateSaltShouldUseProvidedWorkFactor() throws Exception {
// Arrange
BcryptCipherProvider cipherProvider = new BcryptCipherProvider(11);
int workFactor = cipherProvider.getWorkFactor();
// Act
final byte[] saltBytes = cipherProvider.generateSalt();
String salt = new String(saltBytes);
// Assert
final Matcher matcher = Pattern.compile("^\\$2[axy]\\$\\d{2}\\$").matcher(salt);
assertTrue(matcher.find());
assertTrue(salt.contains("$" + workFactor + "$"));
}
/**
* For {@code 1.12.0} the key derivation process was changed. Previously, the entire hash output
* ({@code $2a$10$9XUQnxGEUsRdLqEhxY3xNujOQQkW3spKqxssi.Ox39VhhxB.z4496}) was fed to {@code SHA-512}
* to stretch the hash output to a custom key length (128, 192, or 256 bits) because the Bcrypt hash
* output length is fixed at 184 bits. The new key derivation process only feeds the <em>non-salt
* hash output</em> (({@code jOQQkW3spKqxssi.Ox39VhhxB.z4496})) into the digest.
*/
@Test
void testGetCipherShouldUseHashOutputOnlyToDeriveKey() throws Exception {
// Arrange
BcryptCipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
String saltString = new String(SALT, StandardCharsets.UTF_8);
// Determine the expected key bytes using the new key derivation process
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher(cipherProvider.getWorkFactor(), cipherProvider.getDefaultSaltLength());
byte[] rawSaltBytes = BcryptCipherProvider.extractRawSalt(saltString);
byte[] hashOutputBytes = bcryptSecureHasher.hashRaw(SHORT_PASSWORD.getBytes(StandardCharsets.UTF_8), rawSaltBytes);
MessageDigest sha512 = MessageDigest.getInstance("SHA-512", "BC");
byte[] keyDigestBytes = sha512.digest(Arrays.copyOfRange(hashOutputBytes, hashOutputBytes.length - 31, hashOutputBytes.length));
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Expected key verification
int keyLength = CipherUtility.parseKeyLengthFromAlgorithm(em.getAlgorithm());
byte[] derivedKeyBytes = Arrays.copyOf(keyDigestBytes, keyLength / 8);
Cipher verificationCipher = Cipher.getInstance(em.getAlgorithm());
verificationCipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(derivedKeyBytes, em.getAlgorithm()), new IvParameterSpec(iv));
byte[] verificationBytes = verificationCipher.doFinal(cipherBytes);
String verificationRecovered = new String(verificationBytes, StandardCharsets.UTF_8);
// Assert
assertEquals(PLAINTEXT, recovered);
assertEquals(PLAINTEXT, verificationRecovered);
}
}
@Test
void testGetCipherShouldBeBackwardCompatibleWithFullHashKeyDerivation() throws Exception {
// Arrange
BcryptCipherProvider cipherProvider = new BcryptCipherProvider(4);
final byte[] SALT = cipherProvider.generateSalt();
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption using the legacy key derivation process
Cipher cipher = cipherProvider.getInitializedCipher(em, SHORT_PASSWORD, SALT, new byte[0], DEFAULT_KEY_LENGTH, true, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getLegacyDecryptCipher(em, SHORT_PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldHandleNullSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new BcryptCipherProvider(4);
final String PASSWORD = "shortPassword";
final byte[] SALT = null;
final EncryptionMethod em = EncryptionMethod.AES_CBC;
// Act
// Initialize a cipher for encryption
IllegalArgumentException encryptIae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true));
IllegalArgumentException decryptIae = assertThrows(IllegalArgumentException.class, () -> {
final byte[] iv = new byte[16];
Arrays.fill(iv, (byte) '\0');
cipherProvider.getCipher(em, PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false);
});
// Assert
assertTrue(encryptIae.getMessage().contains("The salt must be of the format"));
assertTrue(decryptIae.getMessage().contains("The salt must be of the format"));
}
}

View File

@ -0,0 +1,297 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import at.favre.lib.crypto.bcrypt.Radix64Encoder;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.Test;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class BcryptSecureHasherTest {
@Test
void testShouldBeDeterministicWithStaticSalt() {
// Arrange
int cost = 4;
int testIterations = 10;
byte[] inputBytes = "This is a sensitive value".getBytes();
final String EXPECTED_HASH_HEX = "24326124303424526b6a4559512f526245447959554b6553304471622e596b4c5331655a2e6c61586550484c69464d783937564c566d47354250454f";
BcryptSecureHasher bcryptSH = new BcryptSecureHasher(cost);
final List<String> results = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
byte[] hash = bcryptSH.hashRaw(inputBytes);
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
}
// Assert
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
}
@Test
void testShouldBeDifferentWithRandomSalt() {
// Arrange
int cost = 4;
int saltLength = 16;
int testIterations = 10;
byte[] inputBytes = "This is a sensitive value".getBytes();
final String EXPECTED_HASH_HEX = "24326124303424546d6c47615342546447463061574d6755324673642e38675a347a6149356d6b4d50594c542e344e68337962455a4678384b676a75";
BcryptSecureHasher bcryptSH = new BcryptSecureHasher(cost, saltLength);
final List<String> results = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
byte[] hash = bcryptSH.hashRaw(inputBytes);
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
}
// Assert
assertEquals(results.size(), results.stream().distinct().collect(Collectors.toList()).size());
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result));
}
@Test
void testShouldHandleArbitrarySalt() {
// Arrange
int cost = 4;
final String input = "This is a sensitive value";
byte[] inputBytes = input.getBytes();
final String EXPECTED_HASH_HEX = "24326124303424526b6a4559512f526245447959554b6553304471622e596b4c5331655a2e6c61586550484c69464d783937564c566d47354250454f";
final String EXPECTED_HASH_BASE64 = "JDJhJDA0JFJrakVZUS9SYkVEeVlVS2VTMERxYi5Za0xTMWVaLmxhWGVQSExpRk14OTdWTFZtRzVCUEVP";
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX);
// Static salt instance
BcryptSecureHasher staticSaltHasher = new BcryptSecureHasher(cost);
BcryptSecureHasher arbitrarySaltHasher = new BcryptSecureHasher(cost, 16);
final byte[] STATIC_SALT = "NiFi Static Salt".getBytes(StandardCharsets.UTF_8);
final String DIFFERENT_STATIC_SALT = "Diff Static Salt";
// Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes);
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT);
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8));
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes);
String staticSaltHashHex = staticSaltHasher.hashHex(input);
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT);
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input);
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input);
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT);
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input);
// Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash);
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash);
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash));
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash));
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex);
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex);
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64);
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64);
}
@Test
void testShouldValidateArbitrarySalt() {
// Arrange
int cost = 4;
final String input = "This is a sensitive value";
byte[] inputBytes = input.getBytes();
// Static salt instance
BcryptSecureHasher secureHasher = new BcryptSecureHasher(cost, 16);
final byte[] STATIC_SALT = "bad_sal".getBytes();
assertThrows(IllegalArgumentException.class, () -> new BcryptSecureHasher(cost, 7));
assertThrows(RuntimeException.class, () -> secureHasher.hashRaw(inputBytes, STATIC_SALT));
assertThrows(RuntimeException.class, () -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
assertThrows(RuntimeException.class, () -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
}
@Test
void testShouldFormatHex() {
// Arrange
String input = "This is a sensitive value";
final String EXPECTED_HASH_HEX = "24326124313224526b6a4559512f526245447959554b6553304471622e5852696135344d4e356c5a44515243575874516c4c696d476669635a776871";
BcryptSecureHasher bcryptSH = new BcryptSecureHasher();
// Act
String hashHex = bcryptSH.hashHex(input);
// Assert
assertEquals(EXPECTED_HASH_HEX, hashHex);
}
@Test
void testShouldFormatBase64() {
// Arrange
String input = "This is a sensitive value";
final String EXPECTED_HASH_BASE64 = "JDJhJDEyJFJrakVZUS9SYkVEeVlVS2VTMERxYi5YUmlhNTRNTjVsWkRRUkNXWHRRbExpbUdmaWNad2hx";
BcryptSecureHasher bcryptSH = new BcryptSecureHasher();
// Act
String hashB64 = bcryptSH.hashBase64(input);
// Assert
assertEquals(EXPECTED_HASH_BASE64, hashB64);
}
@Test
void testShouldHandleNullInput() {
// Arrange
List<String> inputs = Arrays.asList(null, "");
final String EXPECTED_HASH_HEX = "";
final String EXPECTED_HASH_BASE64 = "";
BcryptSecureHasher bcryptSH = new BcryptSecureHasher();
final List<String> hexResults = new ArrayList<>();
final List<String> b64Results = new ArrayList<>();
// Act
for (final String input : inputs) {
String hashHex = bcryptSH.hashHex(input);
hexResults.add(hashHex);
String hashB64 = bcryptSH.hashBase64(input);
b64Results.add(hashB64);
}
// Assert
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
b64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result));
}
@Test
void testShouldVerifyCostBoundary() throws Exception {
// Arrange
final int cost = 14;
// Act and Assert
assertTrue(BcryptSecureHasher.isCostValid(cost));
}
@Test
void testShouldFailCostBoundary() throws Exception {
// Arrange
final List<Integer> costFactors = Arrays.asList(-8, 0, 40);
// Act and Assert
costFactors.forEach(costFactor -> assertFalse(BcryptSecureHasher.isCostValid(costFactor)));
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
final List<Integer> saltLengths = Arrays.asList(0, 16);
// Act and Assert
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher();
saltLengths.forEach(saltLength -> assertTrue(bcryptSecureHasher.isSaltLengthValid(saltLength)));
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
final List<Integer> saltLengths = Arrays.asList(-8, 1);
// Act and Assert
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher();
saltLengths.forEach(saltLength -> assertFalse(bcryptSecureHasher.isSaltLengthValid(saltLength)));
}
@Test
void testShouldConvertRadix64ToBase64() {
// Arrange
final String INPUT_RADIX_64 = "mm7MiKjvXVYCujVUlKRKiu";
final byte[] EXPECTED_BYTES = new Radix64Encoder.Default().decode(INPUT_RADIX_64.getBytes());
// Uses standard Base64 library but removes padding chars
final String EXPECTED_MIME_B64 = Base64.getEncoder().encodeToString(EXPECTED_BYTES).replaceAll("=", "");
// Act
String convertedBase64 = BcryptSecureHasher.convertBcryptRadix64ToMimeBase64(INPUT_RADIX_64);
String convertedRadix64 = BcryptSecureHasher.convertMimeBase64ToBcryptRadix64(convertedBase64);
// Assert
assertEquals(EXPECTED_MIME_B64, convertedBase64);
assertEquals(INPUT_RADIX_64, convertedRadix64);
}
@Test
void testConvertRadix64ToBase64ShouldHandlePeriod() {
// Arrange
final String INPUT_RADIX_64 = "75x373yP7atxMD3pVgsdO.";
final byte[] EXPECTED_BYTES = new Radix64Encoder.Default().decode(INPUT_RADIX_64.getBytes());
// Uses standard Base64 library but removes padding chars
final String EXPECTED_MIME_B64 = Base64.getEncoder().encodeToString(EXPECTED_BYTES).replaceAll("=", "");
// Act
String convertedBase64 = BcryptSecureHasher.convertBcryptRadix64ToMimeBase64(INPUT_RADIX_64);
String convertedRadix64 = BcryptSecureHasher.convertMimeBase64ToBcryptRadix64(convertedBase64);
// Assert
assertEquals(EXPECTED_MIME_B64, convertedBase64);
assertEquals(INPUT_RADIX_64, convertedRadix64);
}
}

View File

@ -0,0 +1,322 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.security.util.KeyDerivationFunction;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class CipherUtilityTest {
private static final Pattern KEY_LENGTH_PATTERN = Pattern.compile("([\\d]+)BIT");
// TripleDES must precede DES for automatic grouping precedence
private static final List<String> CIPHERS = Arrays.asList("AES", "TRIPLEDES", "DES", "RC2", "RC4", "RC5", "TWOFISH");
private static final List<String> SYMMETRIC_ALGORITHMS = Arrays.stream(EncryptionMethod.values())
.map(it -> it.getAlgorithm())
.filter(algorithm -> algorithm.startsWith("PBE") || algorithm.startsWith("AES"))
.collect(Collectors.toList());
private static final Map<String, List<String>> ALGORITHMS_MAPPED_BY_CIPHER = SYMMETRIC_ALGORITHMS
.stream()
.collect(Collectors.groupingBy(algorithm -> CIPHERS.stream().filter(cipher -> algorithm.contains(cipher)).findFirst().get()));
// Manually mapped as of 03/21/21 1.13.0
private static final Map<Integer, List<String>> ALGORITHMS_MAPPED_BY_KEY_LENGTH = new HashMap<>();
static {
ALGORITHMS_MAPPED_BY_KEY_LENGTH.put(40, Arrays.asList("PBEWITHSHAAND40BITRC2-CBC",
"PBEWITHSHAAND40BITRC4"));
ALGORITHMS_MAPPED_BY_KEY_LENGTH.put(64, Arrays.asList("PBEWITHMD5ANDDES",
"PBEWITHSHA1ANDDES"));
ALGORITHMS_MAPPED_BY_KEY_LENGTH.put(112, Arrays.asList("PBEWITHSHAAND2-KEYTRIPLEDES-CBC",
"PBEWITHSHAAND3-KEYTRIPLEDES-CBC"));
ALGORITHMS_MAPPED_BY_KEY_LENGTH.put(128, Arrays.asList("PBEWITHMD5AND128BITAES-CBC-OPENSSL",
"PBEWITHMD5ANDRC2",
"PBEWITHSHA1ANDRC2",
"PBEWITHSHA256AND128BITAES-CBC-BC",
"PBEWITHSHAAND128BITAES-CBC-BC",
"PBEWITHSHAAND128BITRC2-CBC",
"PBEWITHSHAAND128BITRC4",
"PBEWITHSHAANDTWOFISH-CBC",
"AES/CBC/NoPadding",
"AES/CBC/PKCS7Padding",
"AES/CTR/NoPadding",
"AES/GCM/NoPadding"));
ALGORITHMS_MAPPED_BY_KEY_LENGTH.put(192, Arrays.asList("PBEWITHMD5AND192BITAES-CBC-OPENSSL",
"PBEWITHSHA256AND192BITAES-CBC-BC",
"PBEWITHSHAAND192BITAES-CBC-BC",
"AES/CBC/NoPadding",
"AES/CBC/PKCS7Padding",
"AES/CTR/NoPadding",
"AES/GCM/NoPadding"));
ALGORITHMS_MAPPED_BY_KEY_LENGTH.put(256, Arrays.asList("PBEWITHMD5AND256BITAES-CBC-OPENSSL",
"PBEWITHSHA256AND256BITAES-CBC-BC",
"PBEWITHSHAAND256BITAES-CBC-BC",
"AES/CBC/NoPadding",
"AES/CBC/PKCS7Padding",
"AES/CTR/NoPadding",
"AES/GCM/NoPadding"));
}
@BeforeAll
static void setUpOnce() {
Security.addProvider(new BouncyCastleProvider());
// Fix because TRIPLEDES -> DESede
final List<String> tripleDESAlgorithms = ALGORITHMS_MAPPED_BY_CIPHER.remove("TRIPLEDES");
ALGORITHMS_MAPPED_BY_CIPHER.put("DESede", tripleDESAlgorithms);
}
@Test
void testShouldParseCipherFromAlgorithm() {
// Arrange
final Map<String, List<String>> EXPECTED_ALGORITHMS = ALGORITHMS_MAPPED_BY_CIPHER;
// Act
for (final String algorithm: SYMMETRIC_ALGORITHMS) {
String cipher = CipherUtility.parseCipherFromAlgorithm(algorithm);
// Assert
assertTrue(EXPECTED_ALGORITHMS.get(cipher).contains(algorithm));
}
}
@Test
void testShouldParseKeyLengthFromAlgorithm() {
// Arrange
final Map<Integer, List<String>> EXPECTED_ALGORITHMS = ALGORITHMS_MAPPED_BY_KEY_LENGTH;
// Act
for (final String algorithm: SYMMETRIC_ALGORITHMS) {
int keyLength = CipherUtility.parseKeyLengthFromAlgorithm(algorithm);
// Assert
assertTrue(EXPECTED_ALGORITHMS.get(keyLength).contains(algorithm));
}
}
@Test
void testShouldDetermineValidKeyLength() {
// Arrange
// Act
for (final Map.Entry<Integer, List<String>> entry : ALGORITHMS_MAPPED_BY_KEY_LENGTH.entrySet()) {
final int keyLength = entry.getKey();
final List<String> algorithms = entry.getValue();
for (final String algorithm : algorithms) {
// Assert
assertTrue(CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm)));
}
}
}
@Test
void testShouldDetermineInvalidKeyLength() {
// Arrange
// Act
for (final Map.Entry<Integer, List<String>> entry : ALGORITHMS_MAPPED_BY_KEY_LENGTH.entrySet()) {
final int keyLength = entry.getKey();
final List<String> algorithms = entry.getValue();
for (final String algorithm : algorithms) {
final List<Integer> invalidKeyLengths = new ArrayList<>(Arrays.asList(-1, 0, 1));
final Matcher matcher = Pattern.compile("RC\\d").matcher(algorithm);
if (matcher.find()) {
invalidKeyLengths.add(39);
invalidKeyLengths.add(2049);
} else {
invalidKeyLengths.add(keyLength + 1);
}
// Assert
invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))));
}
}
}
@Test
void testShouldDetermineValidKeyLengthForAlgorithm() {
// Arrange
// Act
for (final Map.Entry<Integer, List<String>> entry : ALGORITHMS_MAPPED_BY_KEY_LENGTH.entrySet()) {
final int keyLength = entry.getKey();
final List<String> algorithms = entry.getValue();
for (final String algorithm : algorithms) {
// Assert
assertTrue(CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm));
}
}
}
@Test
void testShouldDetermineInvalidKeyLengthForAlgorithm() {
// Arrange
// Act
for (final Map.Entry<Integer, List<String>> entry : ALGORITHMS_MAPPED_BY_KEY_LENGTH.entrySet()) {
final int keyLength = entry.getKey();
final List<String> algorithms = entry.getValue();
for (final String algorithm : algorithms) {
final List<Integer> invalidKeyLengths = new ArrayList<>(Arrays.asList(-1, 0, 1));
final Matcher matcher = Pattern.compile("RC\\d").matcher(algorithm);
if (matcher.find()) {
invalidKeyLengths.add(39);
invalidKeyLengths.add(2049);
} else {
invalidKeyLengths.add(keyLength + 1);
}
// Assert
invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)));
}
}
// Extra hard-coded checks
String algorithm = "PBEWITHSHA256AND256BITAES-CBC-BC";
int invalidKeyLength = 192;
assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm));
}
@Test
void testShouldGetValidKeyLengthsForAlgorithm() {
// Arrange
final List<Integer> rcKeyLengths = IntStream.rangeClosed(40, 2048)
.boxed().collect(Collectors.toList());
final Map<String, List<Integer>> CIPHER_KEY_SIZES = new HashMap<>();
CIPHER_KEY_SIZES.put("AES", Arrays.asList(128, 192, 256));
CIPHER_KEY_SIZES.put("DES", Arrays.asList(56, 64));
CIPHER_KEY_SIZES.put("DESede", Arrays.asList(56, 64, 112, 128, 168, 192));
CIPHER_KEY_SIZES.put("RC2", rcKeyLengths);
CIPHER_KEY_SIZES.put("RC4", rcKeyLengths);
CIPHER_KEY_SIZES.put("RC5", rcKeyLengths);
CIPHER_KEY_SIZES.put("TWOFISH", Arrays.asList(128, 192, 256));
final List<String> SINGLE_KEY_SIZE_ALGORITHMS = Arrays.stream(EncryptionMethod.values())
.map(encryptionMethod -> encryptionMethod.getAlgorithm())
.filter(algorithm -> parseActualKeyLengthFromAlgorithm(algorithm) != -1)
.collect(Collectors.toList());
final List<String> MULTIPLE_KEY_SIZE_ALGORITHMS = Arrays.stream(EncryptionMethod.values())
.map(encryptionMethod -> encryptionMethod.getAlgorithm())
.filter(algorithm -> !algorithm.contains("PGP"))
.collect(Collectors.toList());
MULTIPLE_KEY_SIZE_ALGORITHMS.removeAll(SINGLE_KEY_SIZE_ALGORITHMS);
// Act
for (final String algorithm : SINGLE_KEY_SIZE_ALGORITHMS) {
final List<Integer> EXPECTED_KEY_SIZES = Arrays.asList(CipherUtility.parseKeyLengthFromAlgorithm(algorithm));
final List<Integer> validKeySizes = CipherUtility.getValidKeyLengthsForAlgorithm(algorithm);
// Assert
assertEquals(EXPECTED_KEY_SIZES, validKeySizes);
}
// Act
for (final String algorithm : MULTIPLE_KEY_SIZE_ALGORITHMS) {
final String cipher = CipherUtility.parseCipherFromAlgorithm(algorithm);
final List<Integer> EXPECTED_KEY_SIZES = CIPHER_KEY_SIZES.get(cipher);
final List<Integer> validKeySizes = CipherUtility.getValidKeyLengthsForAlgorithm(algorithm);
// Assert
assertEquals(EXPECTED_KEY_SIZES, validKeySizes);
}
}
@Test
void testShouldFindSequence() {
// Arrange
byte[] license = ("Licensed to the Apache Software Foundation (ASF) under one or more " +
"contributor license agreements. See the NOTICE file distributed with " +
"this work for additional information regarding copyright ownership. " +
"The ASF licenses this file to You under the Apache License, Version 2.0 " +
"(the \"License\"); you may not use this file except in compliance with " +
"the License. You may obtain a copy of the License at ".getBytes()).getBytes();
byte[] apache = "Apache".getBytes();
byte[] software = "Software".getBytes();
byte[] asf = "ASF".getBytes();
byte[] kafka = "Kafka".getBytes();
// Act
int apacheIndex = CipherUtility.findSequence(license, apache);
int softwareIndex = CipherUtility.findSequence(license, software);
int asfIndex = CipherUtility.findSequence(license, asf);
int kafkaIndex = CipherUtility.findSequence(license, kafka);
// Assert
assertEquals(16, apacheIndex);
assertEquals(23, softwareIndex);
assertEquals(44, asfIndex);
assertEquals(-1, kafkaIndex);
}
@Test
void testShouldExtractRawSalt() {
// Arrange
final byte[] PLAIN_SALT = new byte[16];
Arrays.fill(PLAIN_SALT, (byte) 0xab);
String ARGON2_SALT = Argon2CipherProvider.formSalt(PLAIN_SALT, 8, 1, 1);
String BCRYPT_SALT = BcryptCipherProvider.formatSaltForBcrypt(PLAIN_SALT, 10);
String SCRYPT_SALT = ScryptCipherProvider.formatSaltForScrypt(PLAIN_SALT, 10, 1, 1);
// Act
final Map<Object, byte[]> results = Arrays.stream(KeyDerivationFunction.values())
.filter(kdf -> !kdf.isStrongKDF())
.collect(Collectors.toMap(Function.identity(), kdf -> CipherUtility.extractRawSalt(PLAIN_SALT, kdf)));
results.put(KeyDerivationFunction.ARGON2, CipherUtility.extractRawSalt(ARGON2_SALT.getBytes(), KeyDerivationFunction.ARGON2));
results.put(KeyDerivationFunction.BCRYPT, CipherUtility.extractRawSalt(BCRYPT_SALT.getBytes(), KeyDerivationFunction.BCRYPT));
results.put(KeyDerivationFunction.SCRYPT, CipherUtility.extractRawSalt(SCRYPT_SALT.getBytes(), KeyDerivationFunction.SCRYPT));
results.put(KeyDerivationFunction.PBKDF2, CipherUtility.extractRawSalt(PLAIN_SALT, KeyDerivationFunction.PBKDF2));
// Assert
results.values().forEach(v -> assertArrayEquals(PLAIN_SALT, v));
}
private static int parseActualKeyLengthFromAlgorithm(final String algorithm) {
Matcher matcher = KEY_LENGTH_PATTERN.matcher(algorithm);
if (matcher.find()) {
return Integer.parseInt(matcher.group(1));
} else {
return -1;
}
}
}

View File

@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class HashAlgorithmTest {
@Test
void testDetermineBrokenAlgorithms() throws Exception {
// Arrange
final List<HashAlgorithm> algorithms = Arrays.asList(HashAlgorithm.values());
// Act
final List<HashAlgorithm> brokenAlgorithms = algorithms.stream()
.filter(algorithm -> !algorithm.isStrongAlgorithm())
.collect(Collectors.toList());
// Assert
assertEquals(Arrays.asList(HashAlgorithm.MD2, HashAlgorithm.MD5, HashAlgorithm.SHA1), brokenAlgorithms);
}
@Test
void testShouldBuildAllowableValueDescription() {
// Arrange
final List<HashAlgorithm> algorithms = Arrays.asList(HashAlgorithm.values());
// Act
final List<String> descriptions = algorithms.stream()
.map(algorithm -> algorithm.buildAllowableValueDescription())
.collect(Collectors.toList());
// Assert
descriptions.forEach(description -> {
final Pattern pattern = Pattern.compile(".* \\(\\d+ byte output\\).*");
final Matcher matcher = pattern.matcher(description);
assertTrue(matcher.find());
});
descriptions.stream()
.filter(description -> {
final Pattern pattern = Pattern.compile("MD2|MD5|SHA-1");
final Matcher matcher = pattern.matcher(description);
return matcher.find();
})
.forEach(description -> assertTrue(description.contains("WARNING")));
}
@Test
void testDetermineBlake2Algorithms() {
final List<HashAlgorithm> algorithms = Arrays.asList(HashAlgorithm.values());
// Act
final List<HashAlgorithm> blake2Algorithms = algorithms.stream()
.filter(HashAlgorithm::isBlake2)
.collect(Collectors.toList());
// Assert
assertEquals(Arrays.asList(HashAlgorithm.BLAKE2_160, HashAlgorithm.BLAKE2_256, HashAlgorithm.BLAKE2_384, HashAlgorithm.BLAKE2_512), blake2Algorithms);
}
@Test
void testShouldMatchAlgorithmByName() {
// Arrange
final List<HashAlgorithm> algorithms = Arrays.asList(HashAlgorithm.values());
// Act
for (final HashAlgorithm algorithm : algorithms) {
final List<String> transformedNames = Arrays.asList(algorithm.getName(), algorithm.getName().toUpperCase(), algorithm.getName().toLowerCase());
for (final String name : transformedNames) {
HashAlgorithm found = HashAlgorithm.fromName(name);
// Assert
assertEquals(name.toUpperCase(), found.getName());
}
}
}
}

View File

@ -0,0 +1,398 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.nifi.components.AllowableValue;
import org.apache.nifi.util.StringUtils;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class HashServiceTest {
private static final String KNOWN_VALUE = "apachenifi";
@Test
void testShouldHashValue() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256;
final String EXPECTED_HASH = "dc4bd945723b9c234f1be408e8ceb78660b481008b8ab5b71eb2aa3b4f08357a";
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH);
String threeArgString = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8);
String twoArgString = HashService.hashValue(algorithm, KNOWN_VALUE);
byte[] threeArgStringRaw = HashService.hashValueRaw(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8);
byte[] twoArgStringRaw = HashService.hashValueRaw(algorithm, KNOWN_VALUE);
byte[] twoArgBytesRaw = HashService.hashValueRaw(algorithm, KNOWN_VALUE.getBytes());
final Map<String, Object> scenarios = new HashMap<>();
scenarios.put("threeArgString", threeArgString);
scenarios.put("twoArgString", twoArgString);
scenarios.put("threeArgStringRaw", threeArgStringRaw);
scenarios.put("twoArgStringRaw", twoArgStringRaw);
scenarios.put("twoArgBytesRaw", twoArgBytesRaw);
// Act
for (final Object result : scenarios.values()) {
// Assert
if (result instanceof byte[]) {
assertArrayEquals(EXPECTED_HASH_BYTES, (byte[]) result);
} else {
assertEquals(EXPECTED_HASH, result);
}
}
}
@Test
void testHashValueShouldDifferOnDifferentEncodings() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256;
// Act
String utf8Hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8);
String utf16Hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_16);
// Assert
assertNotEquals(utf8Hash, utf16Hash);
}
/**
* This test ensures that the service properly handles UTF-16 encoded data to return it without
* the Big Endian Byte Order Mark (BOM). Java treats UTF-16 encoded data without a BOM as Big Endian by default on decoding, but when <em>encoding</em>, it inserts a BE BOM in the data.
*
* Examples:
*
* "apachenifi"
*
* * UTF-8: 0x61 0x70 0x61 0x63 0x68 0x65 0x6E 0x69 0x66 0x69
* * UTF-16: 0xFE 0xFF 0x00 0x61 0x00 0x70 0x00 0x61 0x00 0x63 0x00 0x68 0x00 0x65 0x00 0x6E 0x00 0x69 0x00 0x66 0x00 0x69
* * UTF-16LE: 0x61 0x00 0x70 0x00 0x61 0x00 0x63 0x00 0x68 0x00 0x65 0x00 0x6E 0x00 0x69 0x00 0x66 0x00 0x69 0x00
* * UTF-16BE: 0x00 0x61 0x00 0x70 0x00 0x61 0x00 0x63 0x00 0x68 0x00 0x65 0x00 0x6E 0x00 0x69 0x00 0x66 0x00 0x69
*
* The result of "UTF-16" decoding should have the 0xFE 0xFF stripped on return by encoding in UTF-16BE directly, which will not insert a BOM.
*
* See also: <a href="https://unicode.org/faq/utf_bom.html#bom10">https://unicode.org/faq/utf_bom.html#bom10</a>
*/
@Test
void testHashValueShouldHandleUTF16BOMIssue() {
// Arrange
HashAlgorithm algorithm = HashAlgorithm.SHA256;
List<Charset> charsets = Arrays.asList(StandardCharsets.UTF_8, StandardCharsets.UTF_16, StandardCharsets.UTF_16LE, StandardCharsets.UTF_16BE);
final Map<String, String> EXPECTED_SHA_256_HASHES = new HashMap<>();
EXPECTED_SHA_256_HASHES.put("utf_8", "dc4bd945723b9c234f1be408e8ceb78660b481008b8ab5b71eb2aa3b4f08357a");
EXPECTED_SHA_256_HASHES.put("utf_16", "f370019c2a41a8285077beb839f7566240e2f0ca970cb67aed5836b89478df91");
EXPECTED_SHA_256_HASHES.put("utf_16be", "f370019c2a41a8285077beb839f7566240e2f0ca970cb67aed5836b89478df91");
EXPECTED_SHA_256_HASHES.put("utf_16le", "7e285dc64d3a8c3cb4e04304577eebbcb654f2245373874e48e597a8b8f15aff");
// Act
for (final Charset charset : charsets) {
// Calculate the expected hash value given the character set
String hash = HashService.hashValue(algorithm, KNOWN_VALUE, charset);
// Assert
assertEquals(EXPECTED_SHA_256_HASHES.get(translateStringToMapKey(charset.name())), hash);
}
}
@Test
void testHashValueShouldDefaultToUTF8() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256;
// Act
String explicitUTF8Hash = HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8);
String implicitUTF8Hash = HashService.hashValue(algorithm, KNOWN_VALUE);
byte[] explicitUTF8HashBytes = HashService.hashValueRaw(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8);
byte[] implicitUTF8HashBytes = HashService.hashValueRaw(algorithm, KNOWN_VALUE);
byte[] implicitUTF8HashBytesDefault = HashService.hashValueRaw(algorithm, KNOWN_VALUE.getBytes());
// Assert
assertEquals(explicitUTF8Hash, implicitUTF8Hash);
assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytes);
assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytesDefault);
}
@Test
void testShouldRejectNullAlgorithm() {
// Arrange
final List<IllegalArgumentException> errors = new ArrayList<>();
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValue(null, KNOWN_VALUE, StandardCharsets.UTF_8)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValue(null, KNOWN_VALUE)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValueRaw(null, KNOWN_VALUE, StandardCharsets.UTF_8)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValueRaw(null, KNOWN_VALUE)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValueRaw(null, KNOWN_VALUE.getBytes())));
errors.forEach(error -> assertTrue(error.getMessage().contains("The hash algorithm cannot be null")));
}
@Test
void testShouldRejectNullValue() {
// Arrange
final HashAlgorithm algorithm = HashAlgorithm.SHA256;
final List<IllegalArgumentException> errors = new ArrayList<>();
// Act and Assert
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValue(algorithm, null, StandardCharsets.UTF_8)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValue(algorithm, null)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValueRaw(algorithm, null, StandardCharsets.UTF_8)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValueRaw(algorithm, (String) null)));
errors.add(assertThrows(IllegalArgumentException.class,
() -> HashService.hashValueRaw(algorithm, (byte[]) null)));
// Act
errors.forEach(error -> assertTrue(error.getMessage().contains("The value cannot be null")));
}
@Test
void testShouldHashConstantValue() throws Exception {
// Arrange
final List<HashAlgorithm> algorithms = Arrays.asList(HashAlgorithm.values());
/* These values were generated using command-line tools (openssl dgst -md5, shasum [-a 1 224 256 384 512 512224 512256], rhash --sha3-224, b2sum -l 224)
* Ex: {@code $ echo -n "apachenifi" | openssl dgst -md5}
*/
final Map<String, String> EXPECTED_HASHES = new HashMap<>();
EXPECTED_HASHES.put("md2", "25d261790198fa543b3436b4755ded91");
EXPECTED_HASHES.put("md5", "a968b5ec1d52449963dcc517789baaaf");
EXPECTED_HASHES.put("sha_1", "749806dbcab91a695ac85959aca610d84f03c6a7");
EXPECTED_HASHES.put("sha_224", "4933803881a4ccb9b3453b829263d3e44852765db12958267ad46135");
EXPECTED_HASHES.put("sha_256", "dc4bd945723b9c234f1be408e8ceb78660b481008b8ab5b71eb2aa3b4f08357a");
EXPECTED_HASHES.put("sha_384", "a5205271df448e55afc4a553e91a8fea7d60d080d390d1f3484fcb6318abe94174cf3d36ea4eb1a4d5ed7637c99dec0c");
EXPECTED_HASHES.put("sha_512", "0846ae23e122fbe090e94d45f886aa786acf426f56496e816a64e292b78c1bb7a962dbfd32c5c73bbee432db400970e22fd65498c862da72a305311332c6f302");
EXPECTED_HASHES.put("sha_512_224", "ecf78a026035528e3097ea7289257d1819d273f60636060fbba43bfb");
EXPECTED_HASHES.put("sha_512_256", "d90bdd8ad7e19f2d7848a45782d5dbe056a8213a94e03d9a35d6f44dbe7ee6cd");
EXPECTED_HASHES.put("sha3_224", "2e9d1ea677847dce686ca2444cc4525f114443652fcb55af4c7286cd");
EXPECTED_HASHES.put("sha3_256", "b1b3cd90a21ef60caba5ec1bf12ffcb833e52a0ae26f0ab7c4f9ccfa9c5c025b");
EXPECTED_HASHES.put("sha3_384", "ca699a2447032857bf4f7e84fa316264f0c1870f9330031d5d75a0770644353c268b36d0522a3cf62e60f9401aadc37c");
EXPECTED_HASHES.put("sha3_512", "cb9059d9b7ec4fde4d9710160a694e7ac2a4dd9969dee43d730066ded7b80d3eefdb4cae7622d21f6cfe16092e24f1ad6ca5924767118667654cf71b7abaaca4");
EXPECTED_HASHES.put("blake2_160", "7bc5a408dba4f1934d9090c4d75c65bfa0c7c90c");
EXPECTED_HASHES.put("blake2_256", "40b8935dc5ed153846fb08dac8e7999ba04a74f4dab28415c39847a15c211447");
EXPECTED_HASHES.put("blake2_384", "40716eddc8cfcf666d980804fed294c43fe9436a9787367a3086b45d69791fd5cef1a16c17235ea289c1e40a899b4f6b");
EXPECTED_HASHES.put("blake2_512", "5f34525b130c11c469302ef6734bf6eedb1eca5d7445a3c4ae289ab58dd13ef72531966bfe2f67c4bf49c99dd14dae92d245f241482307d29bf25c45a1085026");
// Act
final Map<String, String> generatedHashes = algorithms
.stream()
.collect(Collectors.toMap(HashAlgorithm::getName, algorithm -> HashService.hashValue(algorithm, KNOWN_VALUE, StandardCharsets.UTF_8)));
// Assert
for (final Map.Entry<String, String> entry : generatedHashes.entrySet()) {
final String algorithmName = entry.getKey();
final String hash = entry.getValue();
String key = translateStringToMapKey(algorithmName);
assertEquals(EXPECTED_HASHES.get(key), hash);
}
}
@Test
void testShouldHashEmptyValue() throws Exception {
// Arrange
final List<HashAlgorithm> algorithms = Arrays.asList(HashAlgorithm.values());
final String EMPTY_VALUE = "";
/* These values were generated using command-line tools (openssl dgst -md5, shasum [-a 1 224 256 384 512 512224 512256], rhash --sha3-224, b2sum -l 224)
* Ex: {@code $ echo -n "" | openssl dgst -md5}
*/
final Map<String, String> EXPECTED_HASHES = new HashMap<>();
EXPECTED_HASHES.put("md2", "8350e5a3e24c153df2275c9f80692773");
EXPECTED_HASHES.put("md5", "d41d8cd98f00b204e9800998ecf8427e");
EXPECTED_HASHES.put("sha_1", "da39a3ee5e6b4b0d3255bfef95601890afd80709");
EXPECTED_HASHES.put("sha_224", "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f");
EXPECTED_HASHES.put("sha_256", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
EXPECTED_HASHES.put("sha_384", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b");
EXPECTED_HASHES.put("sha_512", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e");
EXPECTED_HASHES.put("sha_512_224", "6ed0dd02806fa89e25de060c19d3ac86cabb87d6a0ddd05c333b84f4");
EXPECTED_HASHES.put("sha_512_256", "c672b8d1ef56ed28ab87c3622c5114069bdd3ad7b8f9737498d0c01ecef0967a");
EXPECTED_HASHES.put("sha3_224", "6b4e03423667dbb73b6e15454f0eb1abd4597f9a1b078e3f5b5a6bc7");
EXPECTED_HASHES.put("sha3_256", "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a");
EXPECTED_HASHES.put("sha3_384", "0c63a75b845e4f7d01107d852e4c2485c51a50aaaa94fc61995e71bbee983a2ac3713831264adb47fb6bd1e058d5f004");
EXPECTED_HASHES.put("sha3_512", "a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a615b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26");
EXPECTED_HASHES.put("blake2_160", "3345524abf6bbe1809449224b5972c41790b6cf2");
EXPECTED_HASHES.put("blake2_256", "0e5751c026e543b2e8ab2eb06099daa1d1e5df47778f7787faab45cdf12fe3a8");
EXPECTED_HASHES.put("blake2_384", "b32811423377f52d7862286ee1a72ee540524380fda1724a6f25d7978c6fd3244a6caf0498812673c5e05ef583825100");
EXPECTED_HASHES.put("blake2_512", "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce");
// Act
final Map<String, String> generatedHashes = algorithms
.stream()
.collect(Collectors.toMap(HashAlgorithm::getName, algorithm -> HashService.hashValue(algorithm, EMPTY_VALUE, StandardCharsets.UTF_8)));
// Assert
for (final Map.Entry<String, String> entry : generatedHashes.entrySet()) {
final String algorithmName = entry.getKey();
final String hash = entry.getValue();
String key = translateStringToMapKey(algorithmName);
assertEquals(EXPECTED_HASHES.get(key), hash);
}
}
@Test
void testShouldBuildHashAlgorithmAllowableValues() throws Exception {
// Arrange
final List<HashAlgorithm> EXPECTED_ALGORITHMS = Arrays.asList(HashAlgorithm.values());
// Act
final AllowableValue[] allowableValues = HashService.buildHashAlgorithmAllowableValues();
// Assert
assertInstanceOf(AllowableValue[].class, allowableValues);
final List<AllowableValue> valuesList = Arrays.asList(allowableValues);
assertEquals(EXPECTED_ALGORITHMS.size(), valuesList.size());
EXPECTED_ALGORITHMS.forEach(expectedAlgorithm -> {
final AllowableValue matchingValue = valuesList
.stream()
.filter(value -> value.getValue().equals(expectedAlgorithm.getName()))
.findFirst()
.get();
assertEquals(expectedAlgorithm.getName(), matchingValue.getDisplayName());
assertEquals(expectedAlgorithm.buildAllowableValueDescription(), matchingValue.getDescription());
});
}
@Test
void testShouldBuildCharacterSetAllowableValues() throws Exception {
// Arrange
final List<Charset> EXPECTED_CHARACTER_SETS = Arrays.asList(
StandardCharsets.US_ASCII,
StandardCharsets.ISO_8859_1,
StandardCharsets.UTF_8,
StandardCharsets.UTF_16BE,
StandardCharsets.UTF_16LE,
StandardCharsets.UTF_16
);
final Map<String, String> expectedDescriptions = Collections.singletonMap(
"UTF-16",
"This character set normally decodes using an optional BOM at the beginning of the data but encodes by inserting a BE BOM. For hashing, it will be replaced with UTF-16BE. "
);
// Act
final AllowableValue[] allowableValues = HashService.buildCharacterSetAllowableValues();
// Assert
assertInstanceOf(AllowableValue[].class, allowableValues);
final List<AllowableValue> valuesList = Arrays.asList(allowableValues);
assertEquals(EXPECTED_CHARACTER_SETS.size(), valuesList.size());
EXPECTED_CHARACTER_SETS.forEach(charset -> {
final AllowableValue matchingValue = valuesList
.stream()
.filter(value -> value.getValue() == charset.name())
.findFirst()
.get();
assertEquals(charset.name(), matchingValue.getDisplayName());
assertEquals((expectedDescriptions.containsKey(charset.name()) ? expectedDescriptions.get(charset.name()) : charset.displayName()), matchingValue.getDescription());
});
}
@Test
void testShouldHashValueFromStream() throws Exception {
// Arrange
// No command-line md2sum tool available
final List<HashAlgorithm> algorithms = new ArrayList<>(Arrays.asList(HashAlgorithm.values()));
algorithms.remove(HashAlgorithm.MD2);
StringBuilder sb = new StringBuilder();
final int times = 10000;
for (int i = 0; i < times; i++) {
sb.append(String.format("%s: %s\n", StringUtils.leftPad(String.valueOf(i), 5), StringUtils.repeat("apachenifi ", 10)));
}
/* These values were generated using command-line tools (openssl dgst -md5, shasum [-a 1 224 256 384 512 512224 512256], rhash --sha3-224, b2sum -l 160)
* Ex: {@code $ openssl dgst -md5 src/test/resources/HashServiceTest/largefile.txt}
*/
final Map<String, String> EXPECTED_HASHES = new HashMap<>();
EXPECTED_HASHES.put("md5", "8d329076847b678449610a5fb53997d2");
EXPECTED_HASHES.put("sha_1", "09cd981ee7529cfd6268a69c0d53e8117e9c78b1");
EXPECTED_HASHES.put("sha_224", "4d4d58c226959e0775e627a866eaa26bf18121d578b559946aea6f8c");
EXPECTED_HASHES.put("sha_256", "ce50f183a8011a86c5162e94481c6b14ad921a8001746806063b3033e71440eb");
EXPECTED_HASHES.put("sha_384", "62a13a410566856422f0b81b2e6ab26f91b3da1a877a5c24f681d2812f26abbc43fb637954879915b3cd9aad626ca71c");
EXPECTED_HASHES.put("sha_512", "3f036116c78b1d9e2017bb1fd4b04f449839e6434c94442edebffdcdfbac1d79b483978126f0ffb12824f14ecc36a07dc95f0ba04aa68885456f3f6381471e07");
EXPECTED_HASHES.put("sha_512_224", "aa7227a80889366a2325801a5cfa67f29c8f272f4284aecfe5daba3c");
EXPECTED_HASHES.put("sha_512_256", "76faa424ee31bcb1f3a41a848806e288cb064a6bf1867881ee1b439dd8b38e40");
EXPECTED_HASHES.put("sha3_224", "d4bb36bf2d00117ade2e63c6fa2ef5f6714d8b6c7a40d12623f95fd0");
EXPECTED_HASHES.put("sha3_256", "f93ff4178bc7f466444a822191e152332331ba51eee42b952b3be1b46b1921f7");
EXPECTED_HASHES.put("sha3_384", "7e4dfb0073645f059e5837f7c066bffd7f8b5d888b0179a8f0be6bb11c7d631847c468d4d861abcdc96503d91f2a7a78");
EXPECTED_HASHES.put("sha3_512", "bf8e83f3590727e04777406e1d478615cf68468ad8690dba3f22a879e08022864a2b4ad8e8a1cbc88737578abd4b2e8493e3bda39a81af3f21fc529c1a7e3b52");
EXPECTED_HASHES.put("blake2_160", "71dd4324a1f72aa10aaa59ee4d79ceee8d8915e6");
EXPECTED_HASHES.put("blake2_256", "5a25864c69f42adeefc343989babb6972df38da47bb6ce712fbef4474266b539");
EXPECTED_HASHES.put("blake2_384", "52417243317ca01693ba835bd5d6655c73a2f70d811b4d26ddacf9e3b74fc3993f30adc64fb6c23a6a5c1e36771a0b95");
EXPECTED_HASHES.put("blake2_512", "be81dbc396a9e11c6189d2408a956466fb1c784d2d34495f9ca43434041b425675005deaeea1a04b1f44db0200b19cde5a40fd5e88414bb300620bc3d5e30f6a");
// Act
final Map<String, String> generatedHashes = algorithms
.stream()
.collect(Collectors.toMap(HashAlgorithm::getName, algorithm -> {
// Get a new InputStream for each iteration, or it will calculate the hash of an empty input on iterations 1 - n
InputStream input = new ByteArrayInputStream(sb.toString().getBytes());
try {
return HashService.hashValueStreaming(algorithm, input);
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
// Assert
for (final Map.Entry<String, String> entry : generatedHashes.entrySet()) {
final String algorithmName = entry.getKey();
final String hash = entry.getValue();
String key = translateStringToMapKey(algorithmName);
assertEquals(EXPECTED_HASHES.get(key), hash);
}
}
private static String translateStringToMapKey(String string) {
return string.toLowerCase().replaceAll("[-\\/]", "_");
}
}

View File

@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.EncryptionMethod;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.PBEParameterSpec;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class NiFiLegacyCipherProviderTest {
private static List<EncryptionMethod> pbeEncryptionMethods = new ArrayList<>();
private static List<EncryptionMethod> limitedStrengthPbeEncryptionMethods = new ArrayList<>();
private static final String PROVIDER_NAME = "BC";
private static final int ITERATION_COUNT = 1000;
private static final String SHORT_PASSWORD = "shortPassword";
private static byte[] SALT_16_BYTES;
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
SALT_16_BYTES = Hex.decodeHex("aabbccddeeff00112233445566778899".toCharArray());
pbeEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(encryptionMethod -> encryptionMethod.getAlgorithm().toUpperCase().startsWith("PBE"))
.collect(Collectors.toList());
limitedStrengthPbeEncryptionMethods = pbeEncryptionMethods.stream()
.filter(encryptionMethod -> !encryptionMethod.isUnlimitedStrength())
.collect(Collectors.toList());
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider();
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod encryptionMethod : limitedStrengthPbeEncryptionMethods) {
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(SHORT_PASSWORD.length(), encryptionMethod)) {
continue;
}
byte[] salt = cipherProvider.generateSalt(encryptionMethod);
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, salt, true);
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, salt, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider();
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod encryptionMethod : pbeEncryptionMethods) {
byte[] salt = cipherProvider.generateSalt(encryptionMethod);
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, salt, true);
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, salt, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherShouldSupportLegacyCode() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider();
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod encryptionMethod : limitedStrengthPbeEncryptionMethods) {
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(SHORT_PASSWORD.length(), encryptionMethod)) {
continue;
}
byte[] salt = cipherProvider.generateSalt(encryptionMethod);
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(SHORT_PASSWORD, salt, encryptionMethod.getAlgorithm());
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"));
Cipher providedCipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, salt, false);
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherWithoutSaltShouldSupportLegacyCode() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider();
final byte[] SALT = new byte[0];
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(SHORT_PASSWORD.length(), em)) {
continue;
}
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(SHORT_PASSWORD, SALT, em.getAlgorithm());
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"));
Cipher providedCipher = cipherProvider.getCipher(em, SHORT_PASSWORD, false);
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherShouldIgnoreKeyLength() throws Exception {
// Arrange
NiFiLegacyCipherProvider cipherProvider = new NiFiLegacyCipherProvider();
final byte[] SALT = SALT_16_BYTES;
final String plaintext = "This is a plaintext message.";
final List<Integer> KEY_LENGTHS = Arrays.asList(-1, 40, 64, 128, 192, 256);
// Initialize a cipher for encryption
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES;
final Cipher cipher128 = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, true);
byte[] cipherBytes = cipher128.doFinal(plaintext.getBytes("UTF-8"));
// Act
for (final int keyLength : KEY_LENGTHS) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, keyLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
private static Cipher getLegacyCipher(String password, byte[] salt, String algorithm) {
try {
final PBEKeySpec pbeKeySpec = new PBEKeySpec(password.toCharArray());
final SecretKeyFactory factory = SecretKeyFactory.getInstance(algorithm, PROVIDER_NAME);
SecretKey tempKey = factory.generateSecret(pbeKeySpec);
final PBEParameterSpec parameterSpec = new PBEParameterSpec(salt, ITERATION_COUNT);
Cipher cipher = Cipher.getInstance(algorithm, PROVIDER_NAME);
cipher.init(Cipher.ENCRYPT_MODE, tempKey, parameterSpec);
return cipher;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,272 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.EncryptionMethod;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.PBEParameterSpec;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class OpenSSLPKCS5CipherProviderTest {
private static List<EncryptionMethod> pbeEncryptionMethods = new ArrayList<>();
private static List<EncryptionMethod> limitedStrengthPbeEncryptionMethods = new ArrayList<>();
private static final String PROVIDER_NAME = "BC";
private static final int ITERATION_COUNT = 0;
private static final String SHORT_PASSWORD = "shortPassword";
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
pbeEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(encryptionMethod -> encryptionMethod.getAlgorithm().toUpperCase().startsWith("PBE"))
.collect(Collectors.toList());
limitedStrengthPbeEncryptionMethods = pbeEncryptionMethods.stream()
.filter(encryptionMethod -> !encryptionMethod.isUnlimitedStrength())
.collect(Collectors.toList());
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("aabbccddeeff0011".toCharArray());
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(SHORT_PASSWORD.length(), em)) {
continue;
}
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, true);
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("aabbccddeeff0011".toCharArray());
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod em : pbeEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, true);
byte[] cipherBytes = cipher.doFinal(plaintext.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherShouldSupportLegacyCode() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("0011223344556677".toCharArray());
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(SHORT_PASSWORD.length(), em)) {
continue;
}
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(SHORT_PASSWORD, SALT, em.getAlgorithm());
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"));
Cipher providedCipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, false);
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherWithoutSaltShouldSupportLegacyCode() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = new byte[0];
final String plaintext = "This is a plaintext message.";
// Act
for (EncryptionMethod em : limitedStrengthPbeEncryptionMethods) {
if (!CipherUtility.passwordLengthIsValidForAlgorithmOnLimitedStrengthCrypto(SHORT_PASSWORD.length(), em)) {
continue;
}
// Initialize a legacy cipher for encryption
Cipher legacyCipher = getLegacyCipher(SHORT_PASSWORD, SALT, em.getAlgorithm());
byte[] cipherBytes = legacyCipher.doFinal(plaintext.getBytes("UTF-8"));
Cipher providedCipher = cipherProvider.getCipher(em, SHORT_PASSWORD, false);
byte[] recoveredBytes = providedCipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherShouldIgnoreKeyLength() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("aabbccddeeff0011".toCharArray());
final String plaintext = "This is a plaintext message.";
final List<Integer> KEY_LENGTHS = Arrays.asList(-1, 40, 64, 128, 192, 256);
// Initialize a cipher for encryption
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES;
final Cipher cipher128 = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, true);
byte[] cipherBytes = cipher128.doFinal(plaintext.getBytes("UTF-8"));
// Act
for (final int keyLength : KEY_LENGTHS) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, keyLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(plaintext, recovered);
}
}
@Test
void testGetCipherShouldRequireEncryptionMethod() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("0011223344556677".toCharArray());
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(null, SHORT_PASSWORD, SALT, false));
// Assert
assertTrue(iae.getMessage().contains("The encryption method must be specified"));
}
@Test
void testGetCipherShouldRequirePassword() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("0011223344556677".toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, "", SALT, false));
// Assert
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"));
}
@Test
void testGetCipherShouldValidateSaltLength() throws Exception {
// Arrange
OpenSSLPKCS5CipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
final byte[] SALT = Hex.decodeHex("00112233445566".toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, false));
// Assert
assertTrue(iae.getMessage().contains("Salt must be 8 bytes US-ASCII encoded"));
}
@Test
void testGenerateSaltShouldProvideValidSalt() throws Exception {
// Arrange
PBECipherProvider cipherProvider = new OpenSSLPKCS5CipherProvider();
// Act
byte[] salt = cipherProvider.generateSalt();
// Assert
assertEquals(cipherProvider.getDefaultSaltLength(), salt.length);
byte[] notExpected = new byte[cipherProvider.getDefaultSaltLength()];
Arrays.fill(notExpected, (byte) 0x00);
assertFalse(Arrays.equals(notExpected, salt));
}
private static Cipher getLegacyCipher(String password, byte[] salt, String algorithm) throws Exception {
final PBEKeySpec pbeKeySpec = new PBEKeySpec(password.toCharArray());
final SecretKeyFactory factory = SecretKeyFactory.getInstance(algorithm, PROVIDER_NAME);
SecretKey tempKey = factory.generateSecret(pbeKeySpec);
final PBEParameterSpec parameterSpec = new PBEParameterSpec(salt, ITERATION_COUNT);
Cipher cipher = Cipher.getInstance(algorithm, PROVIDER_NAME);
cipher.init(Cipher.ENCRYPT_MODE, tempKey, parameterSpec);
return cipher;
}
}

View File

@ -0,0 +1,484 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.util.StringUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import javax.crypto.Cipher;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class PBKDF2CipherProviderTest {
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess";
private static final String SHORT_PASSWORD = "shortPassword";
private static final String BAD_PASSWORD = "thisIsABadPassword";
private static List<EncryptionMethod> strongKDFEncryptionMethods;
public static final String MICROBENCHMARK = "microbenchmark";
private static final int DEFAULT_KEY_LENGTH = 128;
private static final int TEST_ITERATION_COUNT = 1000;
private final String DEFAULT_PRF = "SHA-512";
private final String SALT_HEX = "0123456789ABCDEFFEDCBA9876543210";
private final String IV_HEX = StringUtils.repeat("01", 16);
private static List<Integer> AES_KEY_LENGTHS;
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
strongKDFEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(EncryptionMethod::isCompatibleWithStrongKDFs)
.collect(Collectors.toList());
AES_KEY_LENGTHS = Arrays.asList(128, 192, 256);
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldRejectInvalidIV() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final int MAX_LENGTH = 15;
final List<byte[]> INVALID_IVS = new ArrayList<>();
for (int length = 0; length <= MAX_LENGTH; length++) {
INVALID_IVS.add(new byte[length]);
}
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final byte[] badIV : INVALID_IVS) {
// Encrypt should print a warning about the bad IV but overwrite it
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true);
// Decrypt should fail
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false));
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final int LONG_KEY_LENGTH = 256;
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, LONG_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, LONG_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testShouldRejectEmptyPRF() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider;
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
String prf = "";
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT));
// Assert
assertTrue(iae.getMessage().contains("Cannot resolve empty PRF"));
}
@Test
void testShouldResolveDefaultPRF() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider;
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
final PBKDF2CipherProvider SHA512_PROVIDER = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
String prf = "sha768";
// Act
cipherProvider = new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT);
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = SHA512_PROVIDER.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
@Test
void testShouldResolveVariousPRFs() throws Exception {
// Arrange
final List<String> PRFS = Arrays.asList("SHA-1", "MD5", "SHA-256", "SHA-384", "SHA-512");
RandomIVPBECipherProvider cipherProvider;
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String prf : PRFS) {
cipherProvider = new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT);
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldSupportExternalCompatibility() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider("SHA-256", TEST_ITERATION_COUNT);
final String PLAINTEXT = "This is a plaintext message.";
// These values can be generated by running `$ ./openssl_pbkdf2.rb` in the terminal
final byte[] SALT = Hex.decodeHex("ae2481bee3d8b5d5b732bf464ea2ff01".toCharArray());
final byte[] IV = Hex.decodeHex("26db997dcd18472efd74dabe5ff36853".toCharArray());
final String CIPHER_TEXT = "92edbabae06add6275a1d64815755a9ba52afc96e2c1a316d3abbe1826e96f6c";
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT.toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
@Test
void testGetCipherShouldHandleDifferentPRFs() throws Exception {
// Arrange
RandomIVPBECipherProvider sha256CP = new PBKDF2CipherProvider("SHA-256", TEST_ITERATION_COUNT);
RandomIVPBECipherProvider sha512CP = new PBKDF2CipherProvider("SHA-512", TEST_ITERATION_COUNT);
final String BAD_PASSWORD = "thisIsABadPassword";
final byte[] SALT = new byte[16];
Arrays.fill(SALT, (byte) 0x11);
final byte[] IV = new byte[16];
Arrays.fill(IV, (byte) 0x22);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
Cipher sha256Cipher = sha256CP.getCipher(encryptionMethod, BAD_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] sha256CipherBytes = sha256Cipher.doFinal(PLAINTEXT.getBytes());
Cipher sha512Cipher = sha512CP.getCipher(encryptionMethod, BAD_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] sha512CipherBytes = sha512Cipher.doFinal(PLAINTEXT.getBytes());
// Assert
assertFalse(Arrays.equals(sha512CipherBytes, sha256CipherBytes));
Cipher sha256DecryptCipher = sha256CP.getCipher(encryptionMethod, BAD_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] sha256RecoveredBytes = sha256DecryptCipher.doFinal(sha256CipherBytes);
assertArrayEquals(PLAINTEXT.getBytes(), sha256RecoveredBytes);
Cipher sha512DecryptCipher = sha512CP.getCipher(encryptionMethod, BAD_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] sha512RecoveredBytes = sha512DecryptCipher.doFinal(sha512CipherBytes);
assertArrayEquals(PLAINTEXT.getBytes(), sha512RecoveredBytes);
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, false));
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
}
}
@Test
void testGetCipherShouldRejectInvalidSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final List<String> INVALID_SALTS = Arrays.asList("pbkdf2", "$3a$11$", "x", "$2a$10$", "", null);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String salt : INVALID_SALTS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, salt != null ? salt.getBytes(): null, DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("The salt must be at least 16 bytes. To generate a salt, use PBKDF2CipherProvider#generateSalt"));
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
// Currently only AES ciphers are compatible with PBKDF2, so redundant to test all algorithms
final List<Integer> VALID_KEY_LENGTHS = AES_KEY_LENGTHS;
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : VALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
final byte[] SALT = Hex.decodeHex(SALT_HEX.toCharArray());
final byte[] IV = Hex.decodeHex(IV_HEX.toCharArray());
// Currently only AES ciphers are compatible with PBKDF2, so redundant to test all algorithms
final List<Integer> VALID_KEY_LENGTHS = Arrays.asList(-1, 40, 64, 112, 512);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : VALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true));
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"));
}
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true")
@Test
void testDefaultConstructorShouldProvideStrongIterationCount() throws Exception {
// Arrange
PBKDF2CipherProvider cipherProvider = new PBKDF2CipherProvider();
// Values taken from http://wildlyinaccurate.com/bcrypt-choosing-a-work-factor/ and http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
// Calculate the iteration count to reach 500 ms
int minimumIterationCount = calculateMinimumIterationCount();
// Act
int iterationCount = cipherProvider.getIterationCount();
// Assert
assertTrue(iterationCount >= minimumIterationCount, "The default iteration count for PBKDF2CipherProvider is too weak. Please update the default value to a stronger level.");
}
/**
* Returns the iteration count required for a derivation to exceed 500 ms on this machine using the default PRF.
* Code adapted from http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
*
* @return the minimum iteration count
*/
private static int calculateMinimumIterationCount() throws Exception {
// High start-up cost, so run multiple times for better benchmarking
final int RUNS = 10;
// Benchmark using an iteration count of 10k
int iterationCount = 10_000;
final byte[] SALT = new byte[16];
Arrays.fill(SALT, (byte) 0x00);
final byte[] IV = new byte[16];
Arrays.fill(IV, (byte) 0x01);
String defaultPrf = new PBKDF2CipherProvider().getPRFName();
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(defaultPrf, iterationCount);
long start;
long end;
double duration;
// Run once to prime the system
start = System.nanoTime();
cipherProvider.getCipher(EncryptionMethod.AES_CBC, MICROBENCHMARK, SALT, IV, DEFAULT_KEY_LENGTH, false);
end = System.nanoTime();
getTime(start, end);
final List<Double> durations = new ArrayList<>();
for (int i = 0; i < RUNS; i++) {
start = System.nanoTime();
cipherProvider.getCipher(EncryptionMethod.AES_CBC, String.format("%s%s", MICROBENCHMARK, i), SALT, IV, DEFAULT_KEY_LENGTH, false);
end = System.nanoTime();
duration = getTime(start, end);
durations.add(duration);
}
duration = durations.stream().mapToDouble(Double::doubleValue).sum() / durations.size();
// Keep increasing iteration count until the estimated duration is over 500 ms
while (duration < 500) {
iterationCount *= 2;
duration *= 2;
}
return iterationCount;
}
private static double getTime(final long start, final long end) {
return (end - start) / 1_000_000.0;
}
@Test
void testGenerateSaltShouldProvideValidSalt() throws Exception {
// Arrange
RandomIVPBECipherProvider cipherProvider = new PBKDF2CipherProvider(DEFAULT_PRF, TEST_ITERATION_COUNT);
// Act
byte[] salt = cipherProvider.generateSalt();
// Assert
assertEquals(16, salt.length);
byte[] notExpected = new byte[16];
Arrays.fill(notExpected, (byte) 0x00);
assertFalse(Arrays.equals(notExpected, salt));
}
}

View File

@ -14,199 +14,203 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package org.apache.nifi.security.util.crypto package org.apache.nifi.security.util.crypto;
import org.bouncycastle.util.encoders.Hex import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors import java.util.ArrayList;
import java.util.stream.Stream import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue import static org.junit.jupiter.api.Assertions.assertTrue;
class PBKDF2SecureHasherTest { public class PBKDF2SecureHasherTest {
private static final byte[] STATIC_SALT = "NiFi Static Salt".getBytes(StandardCharsets.UTF_8);
@Test @Test
void testShouldBeDeterministicWithStaticSalt() { void testShouldBeDeterministicWithStaticSalt() {
// Arrange // Arrange
int cost = 10_000 int cost = 10_000;
int dkLength = 32 int dkLength = 32;
byte[] inputBytes = "This is a sensitive value".bytes byte[] inputBytes = "This is a sensitive value".getBytes();
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511" final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511";
// Act // Act
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(cost, dkLength) PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(cost, dkLength);
List<String> results = Stream.iterate(0, n -> n + 1) List<String> results = Stream.iterate(0, n -> n + 1)
.limit(10) .limit(10)
.map(iteration -> { .map(iteration -> {
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes) byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes);
return new String(Hex.encode(hash)) return new String(Hex.encode(hash));
}) })
.collect(Collectors.toList()) .collect(Collectors.toList());
// Assert // Assert
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result)) results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
} }
@Test @Test
void testShouldBeDifferentWithRandomSalt() { void testShouldBeDifferentWithRandomSalt() {
// Arrange // Arrange
String prf = "SHA512" String prf = "SHA512";
int cost = 10_000 int cost = 10_000;
int saltLength = 16 int saltLength = 16;
int dkLength = 32 int dkLength = 32;
byte[] inputBytes = "This is a sensitive value".bytes byte[] inputBytes = "This is a sensitive value".getBytes();
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511" final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511";
//Act //Act
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength) PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength);
List<String> results = Stream.iterate(0, n -> n + 1) List<String> results = Stream.iterate(0, n -> n + 1)
.limit(10) .limit(10)
.map(iteration -> { .map(iteration -> {
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes) byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes);
return new String(Hex.encode(hash)) return new String(Hex.encode(hash));
}) })
.collect(Collectors.toList()) .collect(Collectors.toList());
// Assert // Assert
assertEquals(results.unique().size(), results.size()) assertEquals(results.stream().distinct().collect(Collectors.toList()).size(), results.size());
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result)) results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result));
} }
@Test @Test
void testShouldHandleArbitrarySalt() { void testShouldHandleArbitrarySalt() {
// Arrange // Arrange
String prf = "SHA512" String prf = "SHA512";
int cost = 10_000 int cost = 10_000;
int saltLength = 16 int saltLength = 16;
int dkLength = 32 int dkLength = 32;
def input = "This is a sensitive value" final String input = "This is a sensitive value";
byte[] inputBytes = input.bytes byte[] inputBytes = input.getBytes();
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511" final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511";
final String EXPECTED_HASH_BASE64 = "LEem2AG3Hgh/lHkgecQIgK6ikBO//9CrlLG8ES6lJRE" final String EXPECTED_HASH_BASE64 = "LEem2AG3Hgh/lHkgecQIgK6ikBO//9CrlLG8ES6lJRE";
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX) final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX);
PBKDF2SecureHasher staticSaltHasher = new PBKDF2SecureHasher(cost, dkLength) PBKDF2SecureHasher staticSaltHasher = new PBKDF2SecureHasher(cost, dkLength);
PBKDF2SecureHasher arbitrarySaltHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength) PBKDF2SecureHasher arbitrarySaltHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength);
final byte[] STATIC_SALT = AbstractSecureHasher.STATIC_SALT final String DIFFERENT_STATIC_SALT = "Diff Static Salt";
final String DIFFERENT_STATIC_SALT = "Diff Static Salt"
// Act // Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes) byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes);
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT) byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT);
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8)) byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8));
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes) byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes);
String staticSaltHashHex = staticSaltHasher.hashHex(input) String staticSaltHashHex = staticSaltHasher.hashHex(input);
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT) String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT);
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input) String differentSaltHashHex = arbitrarySaltHasher.hashHex(input);
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input) String staticSaltHashBase64 = staticSaltHasher.hashBase64(input);
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT) String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT);
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input) String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input);
// Assert // Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash) assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash);
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash) assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash);
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash)) assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash));
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash)) assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash));
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex) assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex);
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex) assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex) assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex) assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex);
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64) assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64);
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64) assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64) assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64) assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64);
} }
@Test @Test
void testShouldValidateArbitrarySalt() { void testShouldValidateArbitrarySalt() {
// Assert // Assert
String prf = "SHA512" String prf = "SHA512";
int cost = 10_000 int cost = 10_000;
int saltLength = 16 int saltLength = 16;
int dkLength = 32 int dkLength = 32;
def input = "This is a sensitive value" final String input = "This is a sensitive value";
byte[] inputBytes = input.bytes byte[] inputBytes = input.getBytes();
// Static salt instance // Static salt instance
PBKDF2SecureHasher secureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength) PBKDF2SecureHasher secureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength);
byte[] STATIC_SALT = "bad_sal".bytes byte[] STATIC_SALT = "bad_sal".getBytes();
assertThrows(IllegalArgumentException.class, { -> new PBKDF2SecureHasher(prf, cost, 7, dkLength) }) assertThrows(IllegalArgumentException.class, () -> new PBKDF2SecureHasher(prf, cost, 7, dkLength));
assertThrows(RuntimeException.class, { -> secureHasher.hashRaw(inputBytes, STATIC_SALT) }) assertThrows(RuntimeException.class, () -> secureHasher.hashRaw(inputBytes, STATIC_SALT));
assertThrows(RuntimeException.class, { -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) }) assertThrows(RuntimeException.class, () -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
assertThrows(RuntimeException.class, { -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)) }) assertThrows(RuntimeException.class, () -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
} }
@Test @Test
void testShouldFormatHex() { void testShouldFormatHex() {
// Arrange // Arrange
String input = "This is a sensitive value" String input = "This is a sensitive value";
final String EXPECTED_HASH_HEX = "8f67110e87d225366e2d79ad251d2cf48f8cb15845800452e0e2cff09f95ef1c" final String EXPECTED_HASH_HEX = "8f67110e87d225366e2d79ad251d2cf48f8cb15845800452e0e2cff09f95ef1c";
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher() PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher();
// Act // Act
String hashHex = pbkdf2SecureHasher.hashHex(input) String hashHex = pbkdf2SecureHasher.hashHex(input);
// Assert // Assert
assertEquals(EXPECTED_HASH_HEX, hashHex) assertEquals(EXPECTED_HASH_HEX, hashHex);
} }
@Test @Test
void testShouldFormatBase64() { void testShouldFormatBase64() {
// Arrange // Arrange
String input = "This is a sensitive value" String input = "This is a sensitive value";
final String EXPECTED_HASH_BASE64 = "j2cRDofSJTZuLXmtJR0s9I+MsVhFgARS4OLP8J+V7xw" final String EXPECTED_HASH_BASE64 = "j2cRDofSJTZuLXmtJR0s9I+MsVhFgARS4OLP8J+V7xw";
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher() PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher();
// Act // Act
String hashB64 = pbkdf2SecureHasher.hashBase64(input) String hashB64 = pbkdf2SecureHasher.hashBase64(input);
// Assert // Assert
assertEquals(EXPECTED_HASH_BASE64, hashB64) assertEquals(EXPECTED_HASH_BASE64, hashB64);
} }
@Test @Test
void testShouldHandleNullInput() { void testShouldHandleNullInput() {
// Arrange // Arrange
List<String> inputs = [null, ""] List<String> inputs = Arrays.asList(null, "");
final String EXPECTED_HASH_HEX = "7f2d8d8c7aaa45471f6c05a8edfe0a3f75fe01478cc965c5dce664e2ac6f5d0a" final String EXPECTED_HASH_HEX = "7f2d8d8c7aaa45471f6c05a8edfe0a3f75fe01478cc965c5dce664e2ac6f5d0a";
final String EXPECTED_HASH_BASE64 = "fy2NjHqqRUcfbAWo7f4KP3X+AUeMyWXF3OZk4qxvXQo" final String EXPECTED_HASH_BASE64 = "fy2NjHqqRUcfbAWo7f4KP3X+AUeMyWXF3OZk4qxvXQo";
// Act // Act
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher() PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher();
List<String> hexResults = inputs.stream() List<String> hexResults = inputs.stream()
.map(input -> pbkdf2SecureHasher.hashHex(input)) .map(input -> pbkdf2SecureHasher.hashHex(input))
.collect(Collectors.toList()) .collect(Collectors.toList());
List<String> B64Results = inputs.stream() List<String> B64Results = inputs.stream()
.map(input -> pbkdf2SecureHasher.hashBase64(input)) .map(input -> pbkdf2SecureHasher.hashBase64(input))
.collect(Collectors.toList()) .collect(Collectors.toList());
// Assert // Assert
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result)) hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result)) B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result));
} }
/** /**
@ -217,106 +221,102 @@ class PBKDF2SecureHasherTest {
@Test @Test
void testDefaultCostParamsShouldBeSufficient() { void testDefaultCostParamsShouldBeSufficient() {
// Arrange // Arrange
int testIterations = 100 int testIterations = 100;
byte[] inputBytes = "This is a sensitive value".bytes byte[] inputBytes = "This is a sensitive value".getBytes();
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher() PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher();
def results = [] final List<String> results = new ArrayList<>();
def resultDurations = [] final List<Long> resultDurations = new ArrayList<>();
// Act // Act
testIterations.times { int i -> for (int i = 0; i < testIterations; i++) {
long startNanos = System.nanoTime() long startNanos = System.nanoTime();
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes) byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes);
long endNanos = System.nanoTime() long endNanos = System.nanoTime();
long durationNanos = endNanos - startNanos long durationNanos = endNanos - startNanos;
String hashHex = Hex.encode(hash) String hashHex = Arrays.toString(Hex.encode(hash));
results << hashHex results.add(hashHex);
resultDurations << durationNanos resultDurations.add(durationNanos);
} }
// Assert // Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms final long MIN_DURATION_NANOS = 75_000_000; // 75 ms
assertTrue(resultDurations.min() > MIN_DURATION_NANOS) assertTrue(Collections.min(resultDurations) > MIN_DURATION_NANOS);
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS) assertTrue(resultDurations.stream().mapToLong(Long::longValue).sum() / testIterations > MIN_DURATION_NANOS);
} }
@Test @Test
void testShouldVerifyIterationCountBoundary() throws Exception { void testShouldVerifyIterationCountBoundary() throws Exception {
// Arrange // Arrange
def validIterationCounts = [1, 1000, 1_000_000] final List<Integer> validIterationCounts = Arrays.asList(1, 1000, 1_000_000);
// Act // Act & Assert
def results = validIterationCounts.collect { int i -> for (final int iterationCount : validIterationCounts) {
boolean valid = PBKDF2SecureHasher.isIterationCountValid(i) assertTrue(PBKDF2SecureHasher.isIterationCountValid(iterationCount));
valid
} }
// Assert
assertTrue(results.every())
} }
@Test @Test
void testShouldFailIterationCountBoundary() throws Exception { void testShouldFailIterationCountBoundary() throws Exception {
// Arrange // Arrange
List<Integer> invalidIterationCounts = [-1, 0, Integer.MAX_VALUE + 1] List<Integer> invalidIterationCounts = Arrays.asList(-1, 0, Integer.MAX_VALUE + 1);
// Act and Assert // Act and Assert
invalidIterationCounts.forEach(i -> assertFalse(PBKDF2SecureHasher.isIterationCountValid(i))) invalidIterationCounts.forEach(i -> assertFalse(PBKDF2SecureHasher.isIterationCountValid(i)));
} }
@Test @Test
void testShouldVerifyDKLengthBoundary() throws Exception { void testShouldVerifyDKLengthBoundary() throws Exception {
// Arrange // Arrange
List<Integer> validHLengths = [32, 64] List<Integer> validHLengths = Arrays.asList(32, 64);
// 1 and MAX_VALUE are the length boundaries, inclusive // 1 and MAX_VALUE are the length boundaries, inclusive
List<Integer> validDKLengths = [1, 1000, 1_000_000, Integer.MAX_VALUE] List<Integer> validDKLengths = Arrays.asList(1, 1000, 1_000_000, Integer.MAX_VALUE);
// Act and Assert // Act and Assert
validHLengths.forEach(hLen -> { validHLengths.forEach(hLen -> {
validDKLengths.forEach(dkLength -> { validDKLengths.forEach(dkLength -> {
assertTrue(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength)) assertTrue(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength));
}) });
}) });
} }
@Test @Test
void testShouldFailDKLengthBoundary() throws Exception { void testShouldFailDKLengthBoundary() throws Exception {
// Arrange // Arrange
List<Integer> validHLengths = [32, 64] List<Integer> validHLengths = Arrays.asList(32, 64);
// MAX_VALUE + 1 will become MIN_VALUE because of signed integer math // MAX_VALUE + 1 will become MIN_VALUE because of signed integer math
List<Integer> invalidDKLengths = [-1, 0, Integer.MAX_VALUE + 1, new Integer(Integer.MAX_VALUE * 2 - 1)] List<Integer> invalidDKLengths = Arrays.asList(-1, 0, Integer.MAX_VALUE + 1, new Integer(Integer.MAX_VALUE * 2 - 1));
// Act and Assert // Act and Assert
validHLengths.forEach(hLen -> { validHLengths.forEach(hLen -> {
invalidDKLengths.forEach(dkLength -> { invalidDKLengths.forEach(dkLength -> {
assertFalse(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength)) assertFalse(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength));
}) });
}) });
} }
@Test @Test
void testShouldVerifySaltLengthBoundary() throws Exception { void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange // Arrange
List<Integer> saltLengths = [0, 16, 64] List<Integer> saltLengths = Arrays.asList(0, 16, 64);
// Act and Assert // Act and Assert
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher() PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher();
saltLengths.forEach(saltLength -> assertTrue(pbkdf2SecureHasher.isSaltLengthValid(saltLength))) saltLengths.forEach(saltLength -> assertTrue(pbkdf2SecureHasher.isSaltLengthValid(saltLength)));
} }
@Test @Test
void testShouldFailSaltLengthBoundary() throws Exception { void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange // Arrange
List<Integer> saltLengths = [-8, 1, Integer.MAX_VALUE + 1] List<Integer> saltLengths = Arrays.asList(-8, 1, Integer.MAX_VALUE + 1);
// Act and Assert // Act and Assert
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher() PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher();
saltLengths.forEach(saltLength -> assertFalse(pbkdf2SecureHasher.isSaltLengthValid(saltLength))) saltLengths.forEach(saltLength -> assertFalse(pbkdf2SecureHasher.isSaltLengthValid(saltLength)));
} }
} }

View File

@ -0,0 +1,610 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.security.util.crypto.scrypt.Scrypt;
import org.apache.nifi.util.StringUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.security.SecureRandom;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ScryptCipherProviderTest {
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess";
private static final String SHORT_PASSWORD = "shortPassword";
private static final String BAD_PASSWORD = "thisIsABadPassword";
private static List<EncryptionMethod> strongKDFEncryptionMethods;
private static final int DEFAULT_KEY_LENGTH = 128;
public static final String MICROBENCHMARK = "microbenchmark";
private static List<Integer> AES_KEY_LENGTHS;
RandomIVPBECipherProvider cipherProvider;
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
strongKDFEncryptionMethods = Arrays.stream(EncryptionMethod.values())
.filter(EncryptionMethod::isCompatibleWithStrongKDFs)
.collect(Collectors.toList());
AES_KEY_LENGTHS = Arrays.asList(128, 192, 256);
}
@BeforeEach
void setUp() throws Exception {
// Very fast parameters to test for correctness rather than production values
cipherProvider = new ScryptCipherProvider(4, 1, 1);
}
@Test
void testGetCipherShouldBeInternallyConsistent() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherWithExternalIVShouldBeInternallyConsistent() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("01", 16).toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherWithUnlimitedStrengthShouldBeInternallyConsistent() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final int LONG_KEY_LENGTH = 256;
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, LONG_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, iv, LONG_KEY_LENGTH, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testScryptShouldSupportExternalCompatibility() throws Exception {
// Arrange
// Default values are N=2^14, r=8, p=1, but the provided salt will contain the parameters used
cipherProvider = new ScryptCipherProvider();
final String PLAINTEXT = "This is a plaintext message.";
final int DK_LEN = 128;
// These values can be generated by running `$ ./openssl_scrypt.rb` in the terminal
final byte[] SALT = Hex.decodeHex("f5b8056ea6e66edb8d013ac432aba24a".toCharArray());
final byte[] IV = Hex.decodeHex("76a00f00878b8c3db314ae67804c00a1".toCharArray());
final String CIPHER_TEXT = "604188bf8e9137bc1b24a0ab01973024bc5935e9ae5fedf617bdca028c63c261";
byte[] cipherBytes = Hex.decodeHex(CIPHER_TEXT.toCharArray());
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Sanity check
String rubyKeyHex = "a8efbc0a709d3f89b6bb35b05fc8edf5";
Cipher rubyCipher = Cipher.getInstance(encryptionMethod.getAlgorithm(), "BC");
final SecretKeySpec rubyKey = new SecretKeySpec(Hex.decodeHex(rubyKeyHex.toCharArray()), "AES");
final IvParameterSpec ivSpec = new IvParameterSpec(IV);
rubyCipher.init(Cipher.ENCRYPT_MODE, rubyKey, ivSpec);
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.getBytes());
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec);
assertArrayEquals(PLAINTEXT.getBytes(), rubyCipher.doFinal(rubyCipherBytes));
assertArrayEquals(PLAINTEXT.getBytes(), rubyCipher.doFinal(cipherBytes));
// n$r$p$hex_salt_SL$hex_hash_HL
final String FULL_HASH = "400$8$24$f5b8056ea6e66edb8d013ac432aba24a$a8efbc0a709d3f89b6bb35b05fc8edf5";
final String[] fullHashArr = FULL_HASH.split("\\$");
final String nStr = fullHashArr[0];
final String rStr = fullHashArr[1];
final String pStr = fullHashArr[2];
final String saltHex = fullHashArr[3];
final String hashHex = fullHashArr[4];
final int n = Integer.valueOf(nStr, 16);
final int r = Integer.valueOf(rStr, 16);
final int p = Integer.valueOf(pStr, 16);
// Form Java-style salt with cost params from Ruby-style
String javaSalt = Scrypt.formatSalt(Hex.decodeHex(saltHex.toCharArray()), n, r, p);
// Convert hash from hex to Base64
String base64Hash = CipherUtility.encodeBase64NoPadding(Hex.decodeHex(hashHex.toCharArray()));
assertEquals(hashHex, Hex.encodeHexString(Base64.decodeBase64(base64Hash)));
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, javaSalt.getBytes(), IV, DK_LEN, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
@Test
void testGetCipherShouldHandleSaltWithoutParameters() throws Exception {
// Arrange
// To help Groovy resolve implementation private methods not known at interface level
final ScryptCipherProvider cipherProvider = (ScryptCipherProvider) this.cipherProvider;
final byte[] SALT = new byte[cipherProvider.getDefaultSaltLength()];
new SecureRandom().nextBytes(SALT);
final String EXPECTED_FORMATTED_SALT = cipherProvider.formatSaltForScrypt(SALT);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, true);
byte[] iv = cipher.getIV();
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
// Manually initialize a cipher for decrypt with the expected salt
byte[] parsedSalt = new byte[cipherProvider.getDefaultSaltLength()];
final List<Integer> params = new ArrayList<>();
cipherProvider.parseSalt(EXPECTED_FORMATTED_SALT, parsedSalt, params);
final int n = params.get(0);
final int r = params.get(1);
final int p = params.get(2);
byte[] keyBytes = Scrypt.deriveScryptKey(SHORT_PASSWORD.getBytes(), parsedSalt, n, r, p, DEFAULT_KEY_LENGTH);
SecretKey key = new SecretKeySpec(keyBytes, "AES");
Cipher manualCipher = Cipher.getInstance(encryptionMethod.getAlgorithm(), encryptionMethod.getProvider());
manualCipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
byte[] recoveredBytes = manualCipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
@Test
void testGetCipherShouldNotAcceptInvalidSalts() throws Exception {
// Arrange
final List<String> INVALID_SALTS = Arrays.asList("bad_sal", "$3a$11$", "x", "$2a$10$");
final String LENGTH_MESSAGE = "The raw salt must be greater than or equal to 8 bytes";
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String salt : INVALID_SALTS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, salt.getBytes(), DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains(LENGTH_MESSAGE));
}
}
@Test
void testGetCipherShouldHandleUnformattedSalts() throws Exception {
// Arrange
final List<String> RECOVERABLE_SALTS = Arrays.asList("$ab$00$acbdefghijklmnopqrstuv", "$4$1$1$0123456789abcdef", "$400$1$1$abcdefghijklmnopqrstuv");
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final String salt : RECOVERABLE_SALTS) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, salt.getBytes(), DEFAULT_KEY_LENGTH, true);
// Assert
assertNotNull(cipher);
}
}
@Test
void testGetCipherShouldRejectEmptySalt() throws Exception {
// Arrange
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, BAD_PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use ScryptCipherProvider#generateSalt"));
}
@Test
void testGetCipherForDecryptShouldRequireIV() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("00", 16).toCharArray());
// Act
for (EncryptionMethod em : strongKDFEncryptionMethods) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, SHORT_PASSWORD, SALT, DEFAULT_KEY_LENGTH, false));
// Assert
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"));
}
}
@Test
void testGetCipherShouldAcceptValidKeyLengths() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("01", 16).toCharArray());
final List<Integer> VALID_KEY_LENGTHS = AES_KEY_LENGTHS;
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : VALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
Cipher cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true);
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"));
cipher = cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, false);
byte[] recoveredBytes = cipher.doFinal(cipherBytes);
String recovered = new String(recoveredBytes, "UTF-8");
// Assert
assertEquals(PLAINTEXT, recovered);
}
}
@Test
void testGetCipherShouldNotAcceptInvalidKeyLengths() throws Exception {
// Arrange
final byte[] SALT = cipherProvider.generateSalt();
final byte[] IV = Hex.decodeHex(StringUtils.repeat("00", 16).toCharArray());
// Even though Scrypt can derive keys of arbitrary length, it will fail to validate if the underlying cipher does not support it
final List<Integer> INVALID_KEY_LENGTHS = Arrays.asList(-1, 40, 64, 112, 512);
// Currently only AES ciphers are compatible with Scrypt, so redundant to test all algorithms
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
for (final int keyLength : INVALID_KEY_LENGTHS) {
// Initialize a cipher for encryption
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, SHORT_PASSWORD, SALT, IV, keyLength, true));
// Assert
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"));
}
}
@Test
void testScryptShouldNotAcceptInvalidPassword() {
// Arrange
String emptyPassword = "";
final byte[] salt = new byte[16];
Arrays.fill(salt, (byte) 0x01);
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, emptyPassword, salt, DEFAULT_KEY_LENGTH, true));
// Assert
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"));
}
@Test
void testGenerateSaltShouldUseProvidedParameters() throws Exception {
// Arrange
ScryptCipherProvider cipherProvider = new ScryptCipherProvider(8, 2, 2);
int n = cipherProvider.getN();
int r = cipherProvider.getR();
int p = cipherProvider.getP();
// Act
final String salt = new String(cipherProvider.generateSalt());
// Assert
final Matcher matcher = Pattern.compile("^(?i)\\$s0\\$[a-f0-9]{5,16}\\$").matcher(salt);
assertTrue(matcher.find());
String params = Scrypt.encodeParams(n, r, p);
assertTrue(salt.contains(String.format("$%s$", params)));
}
@Test
void testShouldParseSalt() throws Exception {
// Arrange
ScryptCipherProvider cipherProvider = (ScryptCipherProvider) this.cipherProvider;
final byte[] EXPECTED_RAW_SALT = Hex.decodeHex("f5b8056ea6e66edb8d013ac432aba24a".toCharArray());
final int EXPECTED_N = 1024;
final int EXPECTED_R = 8;
final int EXPECTED_P = 36;
final String FORMATTED_SALT = "$s0$a0824$9bgFbqbmbtuNATrEMquiSg";
byte[] rawSalt = new byte[16];
final List<Integer> params = new ArrayList<>();
// Act
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params);
// Assert
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt);
assertEquals(EXPECTED_N, params.get(0));
assertEquals(EXPECTED_R, params.get(1));
assertEquals(EXPECTED_P, params.get(2));
}
@Test
void testShouldVerifyPBoundary() throws Exception {
// Arrange
final int r = 8;
final int p = 1;
// Act
boolean valid = ScryptCipherProvider.isPValid(r, p);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailPBoundary() throws Exception {
// Arrange
// The p upper bound is calculated with the formula below, when r = 8:
// pBoundary = ((Math.pow(2,32))-1) * (32.0/(r * 128)), where pBoundary = 134217727.96875;
final Map<Integer, Integer> costParameters = new HashMap<>();
costParameters.put(8, 134217729);
costParameters.put(128, 8388608);
costParameters.put(4096, 0);
// Act and Assert
costParameters.entrySet().forEach(entry ->
assertFalse(ScryptCipherProvider.isPValid(entry.getKey(), entry.getValue()))
);
}
@Test
void testShouldVerifyRValue() throws Exception {
// Arrange
final int r = 8;
// Act
boolean valid = ScryptCipherProvider.isRValid(r);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailRValue() throws Exception {
// Arrange
final int r = 0;
// Act
boolean valid = ScryptCipherProvider.isRValid(r);
// Assert
assertFalse(valid);
}
@Test
void testShouldValidateScryptCipherProviderPBoundary() throws Exception {
// Arrange
final int n = 64;
final int r = 8;
final int p = 1;
// Act
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p);
// Assert
assertNotNull(testCipherProvider);
}
@Test
void testShouldCatchInvalidP() throws Exception {
// Arrange
final int n = 64;
final int r = 8;
final int p = 0;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new ScryptCipherProvider(n, r, p));
// Assert
assertTrue(iae.getMessage().contains("Invalid p value exceeds p boundary"));
}
@Test
void testShouldCatchInvalidR() throws Exception {
// Arrange
final int n = 64;
final int r = 0;
final int p = 0;
// Act
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new ScryptCipherProvider(n, r, p));
// Assert
assertTrue(iae.getMessage().contains("Invalid r value; must be greater than 0"));
}
@Test
void testShouldAcceptFormattedSaltWithPlus() throws Exception {
// Arrange
final String FULL_SALT_WITH_PLUS = "$s0$e0801$smJD8vwWI3+uQCHYz2yg0+";
// Act
boolean isScryptSalt = ScryptCipherProvider.isScryptFormattedSalt(FULL_SALT_WITH_PLUS);
// Assert
assertTrue(isScryptSalt);
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true",
disabledReason = "This test can be run on a specific machine to evaluate if the default parameters are sufficient")
@Test
void testDefaultConstructorShouldProvideStrongParameters() {
// Arrange
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider();
/** See this Stack Overflow answer for a good visualization of the interplay between N, r, p <a href="http://stackoverflow.com/a/30308723" rel="noopener">http://stackoverflow.com/a/30308723</a> */
// Act
int n = testCipherProvider.getN();
int r = testCipherProvider.getR();
int p = testCipherProvider.getP();
// Calculate the parameters to reach 500 ms
final List<Integer> minParameters = calculateMinimumParameters(r, p, 1024 * 1024 * 1024);
final int minimumN = minParameters.get(0);
// Assert
assertTrue(n >= minimumN, "The default parameters for ScryptCipherProvider are too weak. Please update the default values to a stronger level.");
}
/**
* Returns the parameters required for a derivation to exceed 500 ms on this machine. Code adapted from http://security.stackexchange.com/questions/17207/recommended-of-rounds-for-bcrypt
*
* @param r the block size in bytes
* @param p the parallelization factor
* @param maxHeapSize the maximum heap size to use in bytes (defaults to 1 GB)
*
* @return the minimum scrypt parameters as [N, r, p]
*/
private static List<Integer> calculateMinimumParameters(final int r, final int p, final int maxHeapSize) {
// High start-up cost, so run multiple times for better benchmarking
final int RUNS = 10;
// Benchmark using N=2^4
int n = (int) Math.pow(2, 4);
int dkLen = 128;
assertTrue(Scrypt.calculateExpectedMemory(n, r, p) <= maxHeapSize);
byte[] salt = new byte[Scrypt.getDefaultSaltLength()];
new SecureRandom().nextBytes(salt);
long start;
long end;
double duration;
// Run once to prime the system
Scrypt.scrypt(MICROBENCHMARK, salt, n, r, p, dkLen);
final List<Double> durations = new ArrayList<>();
for (int i = 0; i < RUNS; i++) {
start = System.nanoTime();
Scrypt.scrypt(MICROBENCHMARK, salt, n, r, p, dkLen);
end = System.nanoTime();
duration = getTime(start, end);
durations.add(duration);
}
duration = durations.stream().mapToDouble(Double::doubleValue).sum() / durations.size();
// Doubling N would double the run time
// Keep increasing N until the estimated duration is over 500 ms
while (duration < 500) {
n *= 2;
duration *= 2;
}
return Arrays.asList(n, r, p);
}
private static double getTime(final long start, final long end) {
return (end - start) / 1_000_000.0;
}
}

View File

@ -0,0 +1,387 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ScryptSecureHasherTest {
private static final byte[] STATIC_SALT = "NiFi Static Salt".getBytes(StandardCharsets.UTF_8);
private static final String SENSITIVE_VALUE = "This is a sensitive value";
@Test
void testShouldBeDeterministicWithStaticSalt() {
// Arrange
int n = 1024;
int r = 8;
int p = 2;
int dkLength = 32;
int testIterations = 10;
byte[] inputBytes = SENSITIVE_VALUE.getBytes();
final String EXPECTED_HASH_HEX = "a67fd2f4b3aa577b8ecdb682e60b4451a84611dcbbc534bce17616056ef8965d";
ScryptSecureHasher scryptSH = new ScryptSecureHasher(n, r, p, dkLength);
final List<String> results = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
byte[] hash = scryptSH.hashRaw(inputBytes);
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
}
// Assert
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
}
@Test
void testShouldBeDifferentWithRandomSalt() {
// Arrange
int n = 1024;
int r = 8;
int p = 2;
int dkLength = 128;
int testIterations = 10;
byte[] inputBytes = SENSITIVE_VALUE.getBytes();
final String EXPECTED_HASH_HEX = "a67fd2f4b3aa577b8ecdb682e60b4451";
ScryptSecureHasher scryptSH = new ScryptSecureHasher(n, r, p, dkLength, 16);
final List<String> results = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
byte[] hash = scryptSH.hashRaw(inputBytes);
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
}
// Assert
assertTrue(results.stream().distinct().collect(Collectors.toList()).size() == results.size());
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result));
}
@Test
void testShouldHandleArbitrarySalt() {
// Arrange
int n = 1024;
int r = 8;
int p = 2;
int dkLength = 32;
final String input = SENSITIVE_VALUE;
byte[] inputBytes = input.getBytes();
final String EXPECTED_HASH_HEX = "a67fd2f4b3aa577b8ecdb682e60b4451a84611dcbbc534bce17616056ef8965d";
final String EXPECTED_HASH_BASE64 = "pn/S9LOqV3uOzbaC5gtEUahGEdy7xTS84XYWBW74ll0";
final byte[] EXPECTED_HASH_BYTES = Hex.decode(EXPECTED_HASH_HEX);
// Static salt instance
ScryptSecureHasher staticSaltHasher = new ScryptSecureHasher(n, r, p, dkLength);
ScryptSecureHasher arbitrarySaltHasher = new ScryptSecureHasher(n, r, p, dkLength, 16);
final String DIFFERENT_STATIC_SALT = "Diff Static Salt";
// Act
byte[] staticSaltHash = staticSaltHasher.hashRaw(inputBytes);
byte[] arbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, STATIC_SALT);
byte[] differentArbitrarySaltHash = arbitrarySaltHasher.hashRaw(inputBytes, DIFFERENT_STATIC_SALT.getBytes(StandardCharsets.UTF_8));
byte[] differentSaltHash = arbitrarySaltHasher.hashRaw(inputBytes);
String staticSaltHashHex = staticSaltHasher.hashHex(input);
String arbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashHex = arbitrarySaltHasher.hashHex(input, DIFFERENT_STATIC_SALT);
String differentSaltHashHex = arbitrarySaltHasher.hashHex(input);
String staticSaltHashBase64 = staticSaltHasher.hashBase64(input);
String arbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8));
String differentArbitrarySaltHashBase64 = arbitrarySaltHasher.hashBase64(input, DIFFERENT_STATIC_SALT);
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input);
// Assert
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash);
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash);
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash));
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash));
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex);
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex);
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex);
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64);
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64);
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64);
}
@Test
void testShouldValidateArbitrarySalt() {
// Arrange
int n = 1024;
int r = 8;
int p = 2;
int dkLength = 32;
final String input = SENSITIVE_VALUE;
byte[] inputBytes = input.getBytes();
// Static salt instance
ScryptSecureHasher secureHasher = new ScryptSecureHasher(n, r, p, dkLength, 16);
final byte[] STATIC_SALT = "bad_sal".getBytes();
assertThrows(IllegalArgumentException.class, () -> new ScryptSecureHasher(n, r, p, dkLength, 7));
assertThrows(RuntimeException.class, () -> secureHasher.hashRaw(inputBytes, STATIC_SALT));
assertThrows(RuntimeException.class, () -> secureHasher.hashHex(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
assertThrows(RuntimeException.class, () -> secureHasher.hashBase64(input, new String(STATIC_SALT, StandardCharsets.UTF_8)));
}
@Test
void testShouldFormatHex() {
// Arrange
String input = SENSITIVE_VALUE;
final String EXPECTED_HASH_HEX = "6a9c827815fe0718af5e336811fc78dd719c8d9505e015283239b9bf1d24ee71";
SecureHasher scryptSH = new ScryptSecureHasher();
// Act
String hashHex = scryptSH.hashHex(input);
// Assert
assertEquals(EXPECTED_HASH_HEX, hashHex);
}
@Test
void testShouldFormatBase64() {
// Arrange
String input = SENSITIVE_VALUE;
final String EXPECTED_HASH_BASE64 = "apyCeBX+BxivXjNoEfx43XGcjZUF4BUoMjm5vx0k7nE";
SecureHasher scryptSH = new ScryptSecureHasher();
// Act
String hashB64 = scryptSH.hashBase64(input);
// Assert
assertEquals(EXPECTED_HASH_BASE64, hashB64);
}
@Test
void testShouldHandleNullInput() {
// Arrange
List<String> inputs = Arrays.asList(null, "");
final String EXPECTED_HASH_HEX = "";
final String EXPECTED_HASH_BASE64 = "";
ScryptSecureHasher scryptSH = new ScryptSecureHasher();
final List<String> hexResults = new ArrayList<>();
final List<String> B64Results = new ArrayList<>();
// Act
for (final String input : inputs) {
String hashHex = scryptSH.hashHex(input);
hexResults.add(hashHex);
String hashB64 = scryptSH.hashBase64(input);
B64Results.add(hashB64);
}
// Assert
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result));
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result));
}
/**
* This test can have the minimum time threshold updated to determine if the performance
* is still sufficient compared to the existing threat model.
*/
@EnabledIfSystemProperty(named = "nifi.test.performance", matches = "true")
@Test
void testDefaultCostParamsShouldBeSufficient() {
// Arrange
int testIterations = 100;
byte[] inputBytes = SENSITIVE_VALUE.getBytes();
ScryptSecureHasher scryptSH = new ScryptSecureHasher();
final List<String> results = new ArrayList<>();
final List<Long> resultDurations = new ArrayList<>();
// Act
for (int i = 0; i < testIterations; i++) {
long startNanos = System.nanoTime();
byte[] hash = scryptSH.hashRaw(inputBytes);
long endNanos = System.nanoTime();
long durationNanos = endNanos - startNanos;
String hashHex = new String(Hex.encode(hash));
results.add(hashHex);
resultDurations.add(durationNanos);
}
// Assert
final long MIN_DURATION_NANOS = 75_000_000; // 75 ms
assertTrue(Collections.min(resultDurations) > MIN_DURATION_NANOS);
assertTrue(resultDurations.stream().mapToLong(Long::longValue).sum() / testIterations > MIN_DURATION_NANOS);
}
@Test
void testShouldVerifyRBoundary() throws Exception {
// Arrange
final int r = 32;
// Act
boolean valid = ScryptSecureHasher.isRValid(r);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailRBoundary() throws Exception {
// Arrange
List<Integer> rValues = Arrays.asList(-8, 0, 2147483647);
// Act and Assert
rValues.forEach(rValue -> assertFalse(ScryptSecureHasher.isRValid(rValue)));
}
@Test
void testShouldVerifyNBoundary() throws Exception {
// Arrange
final Integer n = 16385;
final int r = 8;
// Act and Assert
assertTrue(ScryptSecureHasher.isNValid(n, r));
}
@Test
void testShouldFailNBoundary() throws Exception {
// Arrange
final Map<Integer, Integer> costParameters = new HashMap<>();
costParameters.put(-8, 8);
costParameters.put(0, 32);
//Act and Assert
costParameters.entrySet().forEach(entry ->
assertFalse(ScryptSecureHasher.isNValid(entry.getKey(), entry.getValue()))
);
}
@Test
void testShouldVerifyPBoundary() throws Exception {
// Arrange
final List<Integer> ps = Arrays.asList(1, 8, 1024);
final List<Integer> rs = Arrays.asList(8, 1024, 4096);
// Act and Assert
ps.forEach(p ->
rs.forEach(r ->
assertTrue(ScryptSecureHasher.isPValid(p, r))
)
);
}
@Test
void testShouldFailIfPBoundaryExceeded() throws Exception {
// Arrange
final List<Integer> ps = Arrays.asList(4096 * 64, 1024 * 1024);
final List<Integer> rs = Arrays.asList(4096, 1024 * 1024);
// Act and Assert
ps.forEach(p ->
rs.forEach(r ->
assertFalse(ScryptSecureHasher.isPValid(p, r))
)
);
}
@Test
void testShouldVerifyDKLengthBoundary() throws Exception {
// Arrange
final Integer dkLength = 64;
// Act
boolean valid = ScryptSecureHasher.isDKLengthValid(dkLength);
// Assert
assertTrue(valid);
}
@Test
void testShouldFailDKLengthBoundary() throws Exception {
// Arrange
final List<Integer> dKLengths = Arrays.asList(-8, 0, 2147483647);
// Act and Assert
dKLengths.forEach(dKLength ->
assertFalse(ScryptSecureHasher.isDKLengthValid(dKLength))
);
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
final List<Integer> saltLengths = Arrays.asList(0, 64);
// Act and Assert
ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher();
saltLengths.forEach(saltLength ->
assertTrue(scryptSecureHasher.isSaltLengthValid(saltLength))
);
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
final List<Integer> saltLengths = Arrays.asList(-8, 1, 2147483647);
// Act and Assert
ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher();
saltLengths.forEach(saltLength ->
assertFalse(scryptSecureHasher.isSaltLengthValid(saltLength))
);
}
}

View File

@ -0,0 +1,450 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.scrypt;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
import org.apache.nifi.security.util.crypto.scrypt.Scrypt;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
public class ScryptTest {
private static final String PASSWORD = "shortPassword";
private static final String SALT_HEX = "0123456789ABCDEFFEDCBA9876543210";
private static byte[] SALT_BYTES;
// Small values to test for correctness, not timing
private static final int N = (int) Math.pow(2, 4);
private static final int R = 1;
private static final int P = 1;
private static final int DK_LEN = 128;
private static final long TWO_GIGABYTES = 2048L * 1024 * 1024;
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider());
SALT_BYTES = Hex.decodeHex(SALT_HEX.toCharArray());
}
@Test
void testDeriveScryptKeyShouldBeInternallyConsistent() throws Exception {
// Arrange
final List<byte[]> allKeys = new ArrayList<>();
final int RUNS = 10;
// Act
for (int i = 0; i < RUNS; i++) {
byte[] keyBytes = Scrypt.deriveScryptKey(PASSWORD.getBytes(), SALT_BYTES, N, R, P, DK_LEN);
allKeys.add(keyBytes);
}
// Assert
assertEquals(RUNS, allKeys.size());
allKeys.forEach(key -> assertArrayEquals(allKeys.get(0), key));
}
/**
* This test ensures that the local implementation of Scrypt is compatible with the reference implementation from the Colin Percival paper.
*/
@Test
void testDeriveScryptKeyShouldMatchTestVectors() throws DecoderException, GeneralSecurityException {
// Arrange
// These values are taken from Colin Percival's scrypt paper: https://www.tarsnap.com/scrypt/scrypt.pdf
final byte[] HASH_2 = Hex.decodeHex("fdbabe1c9d3472007856e7190d01e9fe" +
"7c6ad7cbc8237830e77376634b373162" +
"2eaf30d92e22a3886ff109279d9830da" +
"c727afb94a83ee6d8360cbdfa2cc0640");
final byte[] HASH_3 = Hex.decodeHex("7023bdcb3afd7348461c06cd81fd38eb" +
"fda8fbba904f8e3ea9b543f6545da1f2" +
"d5432955613f0fcf62d49705242a9af9" +
"e61e85dc0d651e40dfcf017b45575887");
final List<TestVector> TEST_VECTORS = new ArrayList<>();
TEST_VECTORS.add(new TestVector(
"password",
"NaCl",
1024,
8,
16,
64 * 8,
HASH_2
));
TEST_VECTORS.add(new TestVector(
"pleaseletmein",
"SodiumChloride",
16384,
8,
1,
64 * 8,
HASH_3
));
// Act
for (final TestVector params: TEST_VECTORS) {
long memoryInBytes = Scrypt.calculateExpectedMemory(params.getN(), params.getR(), params.getP());
byte[] calculatedHash = Scrypt.deriveScryptKey(params.getPassword().getBytes(), params.getSalt().getBytes(), params.getN(), params.getR(), params.getP(), params.getDkLen());
// Assert
assertArrayEquals(params.hash, calculatedHash);
}
}
/**
* This test ensures that the local implementation of Scrypt is compatible with the reference implementation from the Colin Percival paper. The test vector requires ~1GB {@code byte[]}
* and therefore the Java heap must be at least 1GB. Because nifi/pom.xml has a {@code surefire} rule which appends {@code -Xmx1G}
* to the Java options, this overrides any IDE options. To ensure the heap is properly set, using the {@code groovyUnitTest} profile will re-append {@code -Xmx3072m} to the Java options.
*/
@Test
void testDeriveScryptKeyShouldMatchExpensiveTestVector() throws Exception {
// Arrange
long totalMemory = Runtime.getRuntime().totalMemory();
assumeTrue(totalMemory >= TWO_GIGABYTES, "Test is being skipped due to JVM heap size. Please run with -Xmx3072m to set sufficient heap size");
// These values are taken from Colin Percival's scrypt paper: https://www.tarsnap.com/scrypt/scrypt.pdf
final byte[] HASH = Hex.decodeHex("2101cb9b6a511aaeaddbbe09cf70f881" +
"ec568d574a2ffd4dabe5ee9820adaa47" +
"8e56fd8f4ba5d09ffa1c6d927c40f4c3" +
"37304049e8a952fbcbf45c6fa77a41a4".toCharArray());
// This test vector requires 2GB heap space and approximately 10 seconds on a consumer machine
String password = "pleaseletmein";
String salt = "SodiumChloride";
int n = 1048576;
int r = 8;
int p = 1;
int dkLen = 64 * 8;
// Act
long memoryInBytes = Scrypt.calculateExpectedMemory(n, r, p);
byte[] calculatedHash = Scrypt.deriveScryptKey(password.getBytes(), salt.getBytes(), n, r, p, dkLen);
// Assert
assertArrayEquals(HASH, calculatedHash);
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true")
@Test
void testShouldCauseOutOfMemoryError() {
SecureRandom secureRandom = new SecureRandom();
for (int i = 10; i <= 31; i++) {
int length = (int) Math.pow(2, i);
byte[] bytes = new byte[length];
secureRandom.nextBytes(bytes);
}
}
@Test
void testDeriveScryptKeyShouldSupportExternalCompatibility() throws Exception {
// Arrange
// These values can be generated by running `$ ./openssl_scrypt.rb` in the terminal
final String EXPECTED_KEY_HEX = "a8efbc0a709d3f89b6bb35b05fc8edf5";
String password = "thisIsABadPassword";
String saltHex = "f5b8056ea6e66edb8d013ac432aba24a";
int n = 1024;
int r = 8;
int p = 36;
int dkLen = 16 * 8;
// Act
long memoryInBytes = Scrypt.calculateExpectedMemory(n, r, p);
byte[] calculatedHash = Scrypt.deriveScryptKey(password.getBytes(), Hex.decodeHex(saltHex.toCharArray()), n, r, p, dkLen);
// Assert
assertArrayEquals(Hex.decodeHex(EXPECTED_KEY_HEX.toCharArray()), calculatedHash);
}
@Test
void testScryptShouldBeInternallyConsistent() throws Exception {
// Arrange
final List<String> allHashes = new ArrayList<>();
final int RUNS = 10;
// Act
for (int i = 0; i < RUNS; i++) {
String hash = Scrypt.scrypt(PASSWORD, SALT_BYTES, N, R, P, DK_LEN);
allHashes.add(hash);
}
// Assert
assertEquals(RUNS, allHashes.size());
allHashes.forEach(hash -> assertEquals(allHashes.get(0), hash));
}
@Test
void testScryptShouldGenerateValidSaltIfMissing() {
// Arrange
// The generated salt should be byte[16], encoded as 22 Base64 chars
final String EXPECTED_SALT_PATTERN = "\\$.+\\$[0-9a-zA-Z\\/+]{22}\\$.+";
// Act
String calculatedHash = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN);
// Assert
System.out.println(calculatedHash);
final Matcher matcher = Pattern.compile(EXPECTED_SALT_PATTERN).matcher(calculatedHash);
assertTrue(matcher.matches());
}
@Test
void testScryptShouldNotAcceptInvalidN() throws Exception {
// Arrange
final int MAX_N = Integer.MAX_VALUE / 128 / R ;
// N must be a power of 2 > 1 and < Integer.MAX_VALUE / 128 / r
final List<Integer> INVALID_NS = Arrays.asList(-2, 0, 1, 3, 4096 - 1, MAX_N);
// Act
for (final int invalidN : INVALID_NS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.getBytes(), SALT_BYTES, invalidN, R, P, DK_LEN));
// Assert
assertTrue(iae.getMessage().contains("N must be a power of 2 greater than 1")
|| iae.getMessage().contains("Parameter N is too large"));
}
}
@Test
void testScryptShouldAcceptValidR() throws Exception {
// Arrange
// Use a large p value to allow r to exceed MAX_R without normal N exceeding MAX_N
int largeP = 1024;
final int maxR = 16384;
// r must be in (0..Integer.MAX_VALUE / 128 / p)
final List<Integer> INVALID_RS = Arrays.asList(0, maxR);
// Act
for (final int invalidR : INVALID_RS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.getBytes(), SALT_BYTES, N, invalidR, largeP, DK_LEN));
// Assert
assertTrue(iae.getMessage().contains("Parameter r must be 1 or greater")
|| iae.getMessage().contains("Parameter r is too large"));
}
}
@Test
void testScryptShouldNotAcceptInvalidP() throws Exception {
// Arrange
final int MAX_P = (int) (Math.ceil(Integer.MAX_VALUE / 128.0));
// p must be in (0..Integer.MAX_VALUE / 128)
final List<Integer> INVALID_PS = Arrays.asList(0, MAX_P);
// Act
for (final int invalidP : INVALID_PS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.getBytes(), SALT_BYTES, N, R, invalidP, DK_LEN));
// Assert
assertTrue(iae.getMessage().contains("Parameter p must be 1 or greater")
|| iae.getMessage().contains("Parameter p is too large"));
}
}
@Test
void testCheckShouldValidateCorrectPassword() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword";
final String EXPECTED_HASH = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN);
// Act
boolean matches = Scrypt.check(PASSWORD, EXPECTED_HASH);
// Assert
assertTrue(matches);
}
@Test
void testCheckShouldNotValidateIncorrectPassword() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword";
final String INCORRECT_PASSWORD = "incorrectPassword";
final String EXPECTED_HASH = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN);
// Act
boolean matches = Scrypt.check(INCORRECT_PASSWORD, EXPECTED_HASH);
// Assert
assertFalse(matches);
}
@Test
void testCheckShouldNotAcceptInvalidPassword() throws Exception {
// Arrange
final String HASH = "$s0$a0801$abcdefghijklmnopqrstuv$abcdefghijklmnopqrstuv";
// Even though the spec allows for empty passwords, the JCE does not, so extend enforcement of that to the user boundary
final List<String> INVALID_PASSWORDS = Arrays.asList("", null);
// Act
for (final String invalidPassword : INVALID_PASSWORDS) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.check(invalidPassword, HASH));
// Assert
assertTrue(iae.getMessage().contains("Password cannot be empty"));
}
}
@Test
void testCheckShouldNotAcceptInvalidHash() throws Exception {
// Arrange
final String PASSWORD = "thisIsABadPassword";
// Even though the spec allows for empty salts, the JCE does not, so extend enforcement of that to the user boundary
final List<String> INVALID_HASHES = Arrays.asList("", null, "$s0$a0801$", "$s0$a0801$abcdefghijklmnopqrstuv$");
// Act
for (final String invalidHash : INVALID_HASHES) {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.check(PASSWORD, invalidHash));
// Assert
assertTrue(iae.getMessage().contains("Hash cannot be empty")
|| iae.getMessage().contains("Hash is not properly formatted"));
}
}
@Test
void testVerifyHashFormatShouldDetectValidHash() throws Exception {
// Arrange
final List<String> VALID_HASHES = Arrays.asList(
"$s0$40801$AAAAAAAAAAAAAAAAAAAAAA$gLSh7ChbHdOIMvZ74XGjV6qF65d9qvQ8n75FeGnM8YM",
"$s0$40801$ABCDEFGHIJKLMNOPQRSTUQ$hxU5g0eH6sRkBqcsiApI8jxvKRT+2QMCenV0GToiMQ8",
"$s0$40801$eO+UUcKYL2gnpD51QCc+gnywQ7Eg9tZeLMlf0XXr2zc$99aTTB39TJo69aZCONQmRdyWOgYsDi+1MI+8D0EgMNM",
"$s0$40801$AAAAAAAAAAAAAAAAAAAAAA$Gk7K9YmlsWbd8FS7e4RKVWnkg9vlsqYnlD593pJ71gg",
"$s0$40801$ABCDEFGHIJKLMNOPQRSTUQ$Ri78VZbrp2cCVmGh2a9Nbfdov8LPnFb49MYyzPCaXmE",
"$s0$40801$eO+UUcKYL2gnpD51QCc+gnywQ7Eg9tZeLMlf0XXr2zc$rZIrP2qdIY7LN4CZAMgbCzl3YhXz6WhaNyXJXqFIjaI",
"$s0$40801$AAAAAAAAAAAAAAAAAAAAAA$GxH68bGykmPDZ6gaPIGOONOT2omlZ7cd0xlcZ9UsY/0",
"$s0$40801$ABCDEFGHIJKLMNOPQRSTUQ$KLGZjWlo59sbCbtmTg5b4k0Nu+biWZRRzhPhN7K5kkI",
"$s0$40801$eO+UUcKYL2gnpD51QCc+gnywQ7Eg9tZeLMlf0XXr2zc$6Ql6Efd2ac44ERoV31CL3Q0J3LffNZKN4elyMHux99Y",
// Uncommon but technically valid
"$s0$F0801$AAAAAAAAAAA$A",
"$s0$40801$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP$A",
"$s0$40801$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"$s0$40801$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP$" +
"ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"$s0$F0801$AAAAAAAAAAA$A",
"$s0$F0801$AAAAAAAAAAA$A",
"$s0$F0801$AAAAAAAAAAA$A",
"$s0$F0801$AAAAAAAAAAA$A",
"$s0$F0801$AAAAAAAAAAA$A"
);
// Act
for (final String validHash : VALID_HASHES) {
boolean isValidHash = Scrypt.verifyHashFormat(validHash);
// Assert
assertTrue(isValidHash);
}
}
@Test
void testVerifyHashFormatShouldDetectInvalidHash() throws Exception {
// Arrange
// Even though the spec allows for empty salts, the JCE does not, so extend enforcement of that to the user boundary
final List<String> INVALID_HASHES = Arrays.asList("", null, "$s0$a0801$", "$s0$a0801$abcdefghijklmnopqrstuv$");
// Act
for (final String invalidHash : INVALID_HASHES) {
boolean isValidHash = Scrypt.verifyHashFormat(invalidHash);
// Assert
assertFalse(isValidHash);
}
}
private class TestVector {
private String password;
private String salt;
private int n;
private int r;
private int p;
private int dkLen;
private byte[] hash;
public TestVector(String password, String salt, int n, int r, int p, int dkLen, byte[] hash) {
this.password = password;
this.salt = salt;
this.n = n;
this.r = r;
this.p = p;
this.dkLen = dkLen;
this.hash = hash;
}
public String getPassword() {
return password;
}
public String getSalt() {
return salt;
}
public int getN() {
return n;
}
public int getR() {
return r;
}
public int getP() {
return p;
}
public int getDkLen() {
return dkLen;
}
public byte[] getHash() {
return hash;
}
}
}