NIFI-11035 Replaced remaining JUnit 4 assertions in nifi-commons with JUnit 5

- Replaced Groovy asserts with JUnit 5 assertions and Groovy shouldFail method Junit 5 with assertThrow method

This closes #6880

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
dan-s1 2023-01-23 22:28:49 +00:00 committed by exceptionfactory
parent 84f48b5e8c
commit 53371844a4
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
44 changed files with 924 additions and 1102 deletions

View File

@ -25,7 +25,10 @@ import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
class QueryGroovyTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertNotEquals
class QueryGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(QueryGroovyTest.class)
@BeforeAll
@ -79,11 +82,11 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceRepeatingResult.value}")
// Assert
assert replaceSingleResult.value == EXPECTED_SINGLE_RESULT
assert replaceSingleResult.resultType == AttributeExpression.ResultType.STRING
assertEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult.value)
assertEquals(AttributeExpression.ResultType.STRING, replaceSingleResult.resultType)
assert replaceRepeatingResult.value == EXPECTED_REPEATING_RESULT
assert replaceRepeatingResult.resultType == AttributeExpression.ResultType.STRING
assertEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult.value)
assertEquals(AttributeExpression.ResultType.STRING, replaceRepeatingResult.resultType)
}
@Test
@ -119,11 +122,11 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceRepeatingResult.value}")
// Assert
assert replaceSingleResult.value == EXPECTED_SINGLE_RESULT
assert replaceSingleResult.resultType == AttributeExpression.ResultType.STRING
assertEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult.value)
assertEquals(AttributeExpression.ResultType.STRING, replaceSingleResult.resultType)
assert replaceRepeatingResult.value == EXPECTED_REPEATING_RESULT
assert replaceRepeatingResult.resultType == AttributeExpression.ResultType.STRING
assertEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult.value)
assertEquals(AttributeExpression.ResultType.STRING, replaceRepeatingResult.resultType)
}
@Test
@ -159,11 +162,11 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceRepeatingResult.value}")
// Assert
assert replaceSingleResult.value == EXPECTED_SINGLE_RESULT
assert replaceSingleResult.resultType == AttributeExpression.ResultType.STRING
assertEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult.value)
assertEquals(AttributeExpression.ResultType.STRING, replaceSingleResult.resultType)
assert replaceRepeatingResult.value == EXPECTED_REPEATING_RESULT
assert replaceRepeatingResult.resultType == AttributeExpression.ResultType.STRING
assertEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult.value)
assertEquals(AttributeExpression.ResultType.STRING, replaceRepeatingResult.resultType)
}
@Test
@ -200,10 +203,10 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceFirstRepeatingResult}")
// Assert
assert replaceSingleResult != EXPECTED_SINGLE_RESULT
assert replaceRepeatingResult != EXPECTED_REPEATING_RESULT
assertNotEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult)
assertNotEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult)
assert replaceFirstSingleResult == EXPECTED_SINGLE_RESULT
assert replaceFirstRepeatingResult == EXPECTED_REPEATING_RESULT
assertEquals(EXPECTED_SINGLE_RESULT, replaceFirstSingleResult)
assertEquals(EXPECTED_REPEATING_RESULT, replaceFirstRepeatingResult)
}
}

View File

@ -23,9 +23,9 @@ import org.junit.jupiter.api.Test;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestValueLookup {

View File

@ -21,12 +21,12 @@ import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestPackageUnpackageV3 {
@ -35,7 +35,7 @@ public class TestPackageUnpackageV3 {
final FlowFilePackager packager = new FlowFilePackagerV3();
final FlowFileUnpackager unpackager = new FlowFileUnpackagerV3();
final byte[] data = "Hello, World!".getBytes("UTF-8");
final byte[] data = "Hello, World!".getBytes(StandardCharsets.UTF_8);
final Map<String, String> map = new HashMap<>();
map.put("abc", "cba");
@ -50,7 +50,7 @@ public class TestPackageUnpackageV3 {
final byte[] decoded = decodedOut.toByteArray();
assertEquals(map, unpackagedAttributes);
assertTrue(Arrays.equals(data, decoded));
assertArrayEquals(data, decoded);
}
}

View File

@ -28,15 +28,14 @@ import org.apache.nifi.hl7.model.HL7Message;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@SuppressWarnings("resource")
public class TestHL7Query {
@ -60,7 +59,7 @@ public class TestHL7Query {
private HL7Message hypoglycemia;
@BeforeEach
public void init() throws IOException, HL7Exception {
public void init() throws HL7Exception {
this.hyperglycemia = createMessage(HYPERGLYCEMIA);
this.hypoglycemia = createMessage(HYPOGLYCEMIA);
}
@ -116,7 +115,7 @@ public class TestHL7Query {
}
@Test
public void testSelectMessage() throws HL7Exception, IOException {
public void testSelectMessage() {
final HL7Query query = HL7Query.compile("SELECT MESSAGE");
final HL7Message msg = hypoglycemia;
final QueryResult result = query.evaluate(msg);
@ -131,7 +130,7 @@ public class TestHL7Query {
@Test
@SuppressWarnings({"unchecked", "rawtypes"})
public void testSelectField() throws HL7Exception, IOException {
public void testSelectField() {
final HL7Query query = HL7Query.compile("SELECT PID.5");
final HL7Message msg = hypoglycemia;
final QueryResult result = query.evaluate(msg);
@ -149,7 +148,7 @@ public class TestHL7Query {
}
@Test
public void testSelectAbnormalTestResult() throws HL7Exception, IOException {
public void testSelectAbnormalTestResult() {
final String query = "DECLARE result AS REQUIRED OBX SELECT result WHERE result.7 != 'N' AND result.1 = 1";
final HL7Query hl7Query = HL7Query.compile(query);
@ -158,7 +157,7 @@ public class TestHL7Query {
}
@Test
public void testFieldEqualsString() throws HL7Exception, IOException {
public void testFieldEqualsString() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.7 = 'L'");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -169,7 +168,7 @@ public class TestHL7Query {
}
@Test
public void testLessThan() throws HL7Exception, IOException {
public void testLessThan() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 < 600");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -180,7 +179,7 @@ public class TestHL7Query {
}
@Test
public void testCompareTwoFields() throws HL7Exception, IOException {
public void testCompareTwoFields() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 < result.6.2");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -191,7 +190,7 @@ public class TestHL7Query {
}
@Test
public void testLessThanOrEqual() throws HL7Exception, IOException {
public void testLessThanOrEqual() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 <= 59");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -206,7 +205,7 @@ public class TestHL7Query {
}
@Test
public void testGreaterThanOrEqual() throws HL7Exception, IOException {
public void testGreaterThanOrEqual() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 >= 59");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -221,7 +220,7 @@ public class TestHL7Query {
}
@Test
public void testGreaterThan() throws HL7Exception, IOException {
public void testGreaterThan() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 > 58");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -236,7 +235,7 @@ public class TestHL7Query {
}
@Test
public void testDistinctValuesReturned() throws HL7Exception, IOException {
public void testDistinctValuesReturned() {
HL7Query hl7Query = HL7Query.compile("DECLARE result1 AS REQUIRED OBX, result2 AS REQUIRED OBX SELECT MESSAGE WHERE result1.7 = 'L' OR result2.7 != 'H'");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -244,7 +243,7 @@ public class TestHL7Query {
}
@Test
public void testAndWithParents() throws HL7Exception, IOException {
public void testAndWithParents() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.7 = 'L' AND result.3.1 = 'GLU'");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -276,7 +275,7 @@ public class TestHL7Query {
}
@Test
public void testIsNull() throws HL7Exception, IOException {
public void testIsNull() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.999 IS NULL");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch());
@ -295,7 +294,7 @@ public class TestHL7Query {
}
@Test
public void testNotNull() throws HL7Exception, IOException {
public void testNotNull() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.999 NOT NULL");
QueryResult result = hl7Query.evaluate(hypoglycemia);
assertFalse(result.isMatch());
@ -313,7 +312,7 @@ public class TestHL7Query {
assertTrue(result.isMatch());
}
private HL7Message createMessage(final String msgText) throws HL7Exception, IOException {
private HL7Message createMessage(final String msgText) throws HL7Exception {
final HapiContext hapiContext = new DefaultHapiContext();
hapiContext.setValidationContext(ValidationContextFactory.noValidation());

View File

@ -25,8 +25,8 @@ import org.junit.jupiter.api.Test
import static groovy.json.JsonOutput.prettyPrint
import static groovy.json.JsonOutput.toJson
import static org.junit.Assert.assertFalse
import static org.junit.Assert.assertTrue
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.mockito.Mockito.mock
class TestStandardValidators {

View File

@ -30,9 +30,9 @@ import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class KeyedCipherPropertyEncryptorTest {
private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;

View File

@ -26,9 +26,9 @@ import org.junit.jupiter.api.Test;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class PasswordBasedCipherPropertyEncryptorTest {
private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;

View File

@ -23,9 +23,9 @@ import org.junit.jupiter.api.Test;
import java.util.Properties;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class PropertyEncryptorFactoryTest {
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.MD5_256AES;

View File

@ -21,8 +21,8 @@ import org.junit.jupiter.api.Test;
import javax.crypto.SecretKey;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class StandardPropertySecretKeyProviderTest {
private static final String SEED = String.class.getName();
@ -42,7 +42,7 @@ public class StandardPropertySecretKeyProviderTest {
final SecretKey secretKey = provider.getSecretKey(propertyEncryptionMethod, SEED);
final int secretKeyLength = secretKey.getEncoded().length;
final String message = String.format("Method [%s] Key Length not matched", propertyEncryptionMethod);
assertEquals(message, propertyEncryptionMethod.getHashLength(), secretKeyLength);
assertEquals(propertyEncryptionMethod.getHashLength(), secretKeyLength, message);
}
}

View File

@ -30,7 +30,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* Abstract base class for tests that walk FieldValue hierarchies.

View File

@ -24,8 +24,8 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestFieldValueLogicalPathBuilder extends AbstractWalkerTest {

View File

@ -24,7 +24,7 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestFieldValueWalker extends AbstractWalkerTest {

View File

@ -29,7 +29,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestSimpleRecordSchema {

View File

@ -52,10 +52,10 @@ import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
@ -589,7 +589,7 @@ public class ResultSetRecordSetTest {
List<RecordField> fields = new ArrayList<>(concreteRecords.size());
int i = 1;
for (Record record : concreteRecords) {
fields.add(new RecordField("record" + String.valueOf(i), RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
fields.add(new RecordField("record" + i, RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
++i;
}
return fields;
@ -614,7 +614,7 @@ public class ResultSetRecordSetTest {
}
private void thenAllDataTypesMatchInputFieldType(final List<RecordField> inputFields, final RecordSchema resultSchema) {
assertEquals("The number of input fields does not match the number of fields in the result schema.", inputFields.size(), resultSchema.getFieldCount());
assertEquals(inputFields.size(), resultSchema.getFieldCount(), "The number of input fields does not match the number of fields in the result schema.");
for (int i = 0; i < inputFields.size(); ++i) {
assertEquals(inputFields.get(i).getDataType(), resultSchema.getField(i).getDataType());
}
@ -640,7 +640,7 @@ public class ResultSetRecordSetTest {
expectedDataType = RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(), decimalDataType.getScale());
}
}
assertEquals("For column " + column.getIndex() + " the converted type is not matching", expectedDataType, actualDataType);
assertEquals(expectedDataType, actualDataType, "For column " + column.getIndex() + " the converted type is not matching");
}
}
@ -648,9 +648,9 @@ public class ResultSetRecordSetTest {
for (RecordField recordField : actualSchema.getFields()) {
if (recordField.getDataType() instanceof ArrayDataType) {
ArrayDataType arrayType = (ArrayDataType) recordField.getDataType();
assertEquals("Array element type for " + recordField.getFieldName()
+ " is not of expected type " + expectedTypes.get(recordField.getFieldName()).toString(),
expectedTypes.get(recordField.getFieldName()), arrayType.getElementType());
assertEquals(expectedTypes.get(recordField.getFieldName()), arrayType.getElementType(),
"Array element type for " + recordField.getFieldName()
+ " is not of expected type " + expectedTypes.get(recordField.getFieldName()).toString());
} else {
fail("RecordField " + recordField.getFieldName() + " is not instance of ArrayDataType");
}
@ -658,7 +658,7 @@ public class ResultSetRecordSetTest {
}
private void thenAllDataTypesAreChoice(final List<RecordField> inputFields, final RecordSchema resultSchema) {
assertEquals("The number of input fields does not match the number of fields in the result schema.", inputFields.size(), resultSchema.getFieldCount());
assertEquals(inputFields.size(), resultSchema.getFieldCount(), "The number of input fields does not match the number of fields in the result schema.");
DataType expectedType = getBroadestChoiceDataType();
for (int i = 0; i < inputFields.size(); ++i) {

View File

@ -16,14 +16,16 @@
*/
package org.apache.nifi.security.util
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
class TlsConfigurationTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
class TlsConfigurationTest {
private static final Logger logger = LoggerFactory.getLogger(TlsConfigurationTest.class)
@BeforeAll
@ -33,17 +35,6 @@ class TlsConfigurationTest extends GroovyTestCase {
}
}
@BeforeEach
void setUp() {
super.setUp()
}
@AfterEach
void tearDown() {
}
@Test
void testShouldParseJavaVersion() {
// Arrange
@ -57,7 +48,8 @@ class TlsConfigurationTest extends GroovyTestCase {
logger.info("Major versions: ${majorVersions}")
// Assert
assert majorVersions == (5..12)
assertTrue(majorVersions.stream()
.allMatch(num -> num >= 5 && num <= 12))
}
@Test
@ -72,9 +64,9 @@ class TlsConfigurationTest extends GroovyTestCase {
// Assert
if (javaMajorVersion < 11) {
assert tlsVersions == ["TLSv1.2"] as String[]
assertArrayEquals(new String[]{"TLSv1.2"}, tlsVersions)
} else {
assert tlsVersions == ["TLSv1.3", "TLSv1.2"] as String[]
assertArrayEquals(new String[]{"TLSv1.3", "TLSv1.2"}, tlsVersions)
}
}
@ -90,9 +82,9 @@ class TlsConfigurationTest extends GroovyTestCase {
// Assert
if (javaMajorVersion < 11) {
assert tlsVersion == "TLSv1.2"
assertEquals("TLSv1.2", tlsVersion)
} else {
assert tlsVersion == "TLSv1.3"
assertEquals("TLSv1.3", tlsVersion)
}
}
}

View File

@ -20,7 +20,11 @@ import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x500.style.IETFUtils
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.Extensions
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.operator.OperatorCreationException
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest
@ -35,16 +39,31 @@ import javax.net.ssl.SSLException
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSocket
import java.security.*
import java.security.InvalidKeyException
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.NoSuchAlgorithmException
import java.security.NoSuchProviderException
import java.security.SignatureException
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.*
import java.util.concurrent.Callable
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executors
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import static org.junit.Assert.assertTrue
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class CertificateUtilsTest extends GroovyTestCase {
class CertificateUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtilsTest.class)
private static final int KEY_SIZE = 2048
@ -84,10 +103,15 @@ class CertificateUtilsTest extends GroovyTestCase {
*
* @param dn the DN
* @return the certificate
* @throws IOException* @throws NoSuchAlgorithmException* @throws java.security.cert.CertificateException* @throws java.security.NoSuchProviderException* @throws java.security.SignatureException* @throws java.security.InvalidKeyException* @throws OperatorCreationException
* @throws IOException* @throws NoSuchAlgorithmException
* @throws java.security.cert.CertificateException*
* @throws java.security.NoSuchProviderException
* @throws java.security.SignatureException
* @throws OperatorCreationException
*/
private
static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException,
NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateSelfSignedX509Certificate(keyPair, dn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
}
@ -99,10 +123,16 @@ class CertificateUtilsTest extends GroovyTestCase {
* @param issuerDn the issuer DN
* @param issuerKey the issuer private key
* @return the certificate
* @throws IOException* @throws NoSuchAlgorithmException* @throws CertificateException* @throws NoSuchProviderException* @throws SignatureException* @throws InvalidKeyException* @throws OperatorCreationException
* @throws IOException
* @throws NoSuchAlgorithmException
* @throws CertificateException
* @throws NoSuchProviderException*
* @throws SignatureException* @throws InvalidKeyException
* @throws OperatorCreationException
*/
private
static X509Certificate generateIssuedCertificate(String dn, X509Certificate issuer, KeyPair issuerKey) throws IOException, NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
static X509Certificate generateIssuedCertificate(String dn, X509Certificate issuer, KeyPair issuerKey) throws IOException,
NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKey, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
}
@ -143,8 +173,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Converted certificate: ${convertedCertificate.class.canonicalName} ${convertedCertificate.subjectDN.toString()} (${convertedCertificate.getSerialNumber()})")
// Assert
assert convertedCertificate instanceof X509Certificate
assert convertedCertificate == EXPECTED_NEW_CERTIFICATE
assertEquals(EXPECTED_NEW_CERTIFICATE, convertedCertificate)
}
@Test
@ -163,9 +192,9 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Client auth (noneSocket): ${noneClientAuthStatus}")
// Assert
assert needClientAuthStatus == ClientAuth.REQUIRED
assert wantClientAuthStatus == ClientAuth.WANT
assert noneClientAuthStatus == ClientAuth.NONE
assertEquals(ClientAuth.REQUIRED, needClientAuthStatus)
assertEquals(ClientAuth.WANT, wantClientAuthStatus)
assertEquals(ClientAuth.NONE, noneClientAuthStatus)
}
@Test
@ -210,9 +239,7 @@ class CertificateUtilsTest extends GroovyTestCase {
}
// Assert
assert resolvedServerDNs.every { String serverDN ->
CertificateUtils.compareDNs(serverDN, EXPECTED_DN)
}
resolvedServerDNs.stream().forEach(serverDN -> assertTrue(CertificateUtils.compareDNs(serverDN, EXPECTED_DN)))
}
@Test
@ -231,7 +258,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}")
// Assert
assert !clientDN
assertNull(clientDN)
}
@Test
@ -257,7 +284,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}")
// Assert
assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN)
assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
}
@Test
@ -280,7 +307,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}")
// Assert
assert CertificateUtils.compareDNs(clientDN, null)
assertTrue(CertificateUtils.compareDNs(clientDN, null))
}
@ -307,7 +334,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}")
// Assert
assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN)
assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
}
@Test
@ -326,13 +353,11 @@ class CertificateUtilsTest extends GroovyTestCase {
] as SSLSocket
// Act
def msg = shouldFail(CertificateException) {
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket)
logger.info("Extracted client DN: ${clientDN}")
}
CertificateException ce = assertThrows(CertificateException.class,
() -> CertificateUtils.extractPeerDNFromSSLSocket(mockSocket))
// Assert
assert msg =~ "peer not authenticated"
assertTrue(ce.getMessage().contains("peer not authenticated"))
}
@Test
@ -374,14 +399,13 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.matches("DN 1, empty: ${dn1MatchesEmpty}")
// Assert
assert dn1MatchesSelf
assert dn1MatchesReversed
assert emptyMatchesEmpty
assert nullMatchesNull
assertTrue(dn1MatchesReversed)
assertTrue(emptyMatchesEmpty)
assertTrue(nullMatchesNull)
assert !dn1MatchesDn2
assert !dn1MatchesDn2Reversed
assert !dn1MatchesEmpty
assertFalse(dn1MatchesDn2)
assertFalse(dn1MatchesDn2Reversed)
assertFalse(dn1MatchesEmpty)
}
@Test
@ -545,10 +569,9 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Issued certificate with subject: ${certificate.getSubjectDN().name} and SAN: ${certificate.getSubjectAlternativeNames().join(",")}")
// Assert
assert certificate instanceof X509Certificate
assert certificate.getSubjectDN().name == SUBJECT_DN
assert certificate.getSubjectAlternativeNames().size() == SANS.size()
assert certificate.getSubjectAlternativeNames()*.last().containsAll(SANS)
assertEquals(SUBJECT_DN, certificate.getSubjectDN().name)
assertEquals(SANS.size(), certificate.getSubjectAlternativeNames().size())
assertTrue(certificate.getSubjectAlternativeNames()*.last().containsAll(SANS))
}
@Test
@ -575,9 +598,9 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Unrelated results: ${unrelatedResults}")
// Assert
assert directResults.every()
assert causedResults.every()
assert !unrelatedResults.any()
assertTrue(directResults.every())
assertTrue(causedResults.every())
assertFalse(unrelatedResults.any())
}
@Test

View File

@ -30,7 +30,10 @@ import javax.crypto.spec.SecretKeySpec
import java.security.SecureRandom
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assumptions.assumeTrue
class AESKeyedCipherProviderGroovyTest {
@ -80,7 +83,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -107,7 +110,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -148,7 +151,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
}
@ -161,12 +164,11 @@ class AESKeyedCipherProviderGroovyTest {
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
def msg = shouldFail(IllegalArgumentException) {
cipherProvider.getCipher(encryptionMethod, null, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, null, true))
// Assert
assert msg =~ "The key must be specified"
assertTrue(iae.message.contains("The key must be specified"))
}
@Test
@ -175,17 +177,16 @@ class AESKeyedCipherProviderGroovyTest {
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider()
SecretKey localKey = new SecretKeySpec(Hex.decodeHex("0123456789ABCDEF" as char[]), "AES")
assert ![128, 192, 256].contains(localKey.encoded.length)
assertFalse([128, 192, 256].contains(localKey.encoded.length))
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
def msg = shouldFail(IllegalArgumentException) {
cipherProvider.getCipher(encryptionMethod, localKey, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, localKey, true))
// Assert
assert msg =~ "The key must be of length \\[128, 192, 256\\]"
assertTrue(iae.message.contains("The key must be of length [128, 192, 256]"))
}
@Test
@ -194,12 +195,11 @@ class AESKeyedCipherProviderGroovyTest {
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider()
// Act
def msg = shouldFail(IllegalArgumentException) {
cipherProvider.getCipher(null, key, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(null, key, true))
// Assert
assert msg =~ "The encryption method must be specified"
assertTrue(iae.message.contains("The encryption method must be specified"))
}
@Test
@ -210,12 +210,11 @@ class AESKeyedCipherProviderGroovyTest {
final EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES
// Act
def msg = shouldFail(IllegalArgumentException) {
cipherProvider.getCipher(encryptionMethod, key, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, key, true))
// Assert
assert msg =~ " requires a PBECipherProvider"
assertTrue(iae.message.contains("requires a PBECipherProvider"))
}
@Test
@ -244,7 +243,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
@Test
@ -264,12 +263,11 @@ class AESKeyedCipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(em, key, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, key, false))
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.message.contains("Cannot decrypt without a valid IV"))
}
}
@ -290,13 +288,12 @@ class AESKeyedCipherProviderGroovyTest {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true)
// Decrypt should fail
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, false)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, key, badIV, false))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@ -317,12 +314,11 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("IV after encrypt: ${Hex.encodeHexString(cipher.getIV())}")
// Decrypt should fail
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, false)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, key, badIV, false))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}

View File

@ -20,7 +20,6 @@ import org.apache.commons.codec.binary.Base64
import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.Assumptions
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
@ -33,9 +32,14 @@ import javax.crypto.spec.SecretKeySpec
import java.nio.charset.StandardCharsets
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class Argon2CipherProviderGroovyTest extends GroovyTestCase {
class Argon2CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(Argon2CipherProviderGroovyTest.class)
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess"
@ -92,7 +96,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -128,9 +132,9 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assert rubyCipher.doFinal(rubyCipherBytes) == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
logger.sanity("Decrypted generated cipher text successfully")
assert rubyCipher.doFinal(cipherBytes) == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text successfully")
// $argon2id$v=19$m=memory,t=iterations,p=parallelism$saltB64$hashB64
@ -149,7 +153,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
def saltB64 = hashComponents[4]
byte[] salt = Base64.decodeBase64(saltB64)
logger.info("Salt: ${Hex.encodeHexString(salt)}")
assert salt == SALT
assertArrayEquals(SALT, salt)
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}")
@ -161,7 +165,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
@Test
@ -181,12 +185,11 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true)
// Decrypt should fail
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false))
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@ -214,7 +217,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -244,7 +247,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -263,13 +266,12 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
logger.expected(iae.getMessage())
// Assert
assert msg =~ LENGTH_MESSAGE
assertTrue(iae.getMessage().contains(LENGTH_MESSAGE))
}
}
@ -290,7 +292,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
// Assert
assert cipher
assertNotNull(cipher)
}
}
@ -303,14 +305,12 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
def msg =
shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
logger.expected(iae.getMessage())
// Assert
assert msg =~ "The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()"
assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()"))
}
@Test
@ -333,13 +333,15 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
// Assert
boolean isValidFormattedSalt = cipherProvider.isArgon2FormattedSalt(fullSalt)
logger.info("Salt is Argon2 format: ${isValidFormattedSalt}")
assert isValidFormattedSalt
assertTrue(isValidFormattedSalt)
boolean fullSaltIsValidLength = FULL_SALT_LENGTH_RANGE.contains(saltBytes.length)
logger.info("Salt length (${fullSalt.length()}) in valid range (${FULL_SALT_LENGTH_RANGE})")
assert fullSaltIsValidLength
assertTrue(fullSaltIsValidLength)
assert rawSaltBytes != [(0x00 as byte) * 16]
byte [] notExpected = new byte[16]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, rawSaltBytes))
}
@Test
@ -360,13 +362,12 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
logger.expected(iae.getMessage())
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@ -397,7 +398,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -415,14 +416,13 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
INVALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
}
logger.expected(msg)
// Initialize a cipher for
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
logger.expected(iae.getMessage())
// Assert
assert msg =~ "${keyLength} is not a valid key length for AES"
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@ -435,12 +435,11 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
def msg = shouldFail(IllegalArgumentException) {
cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true))
// Assert
assert msg =~ "Encryption with an empty password is not supported"
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
}
@Test
@ -463,10 +462,10 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params)
// Assert
assert rawSalt == EXPECTED_RAW_SALT
assert params[0] == EXPECTED_MEMORY
assert params[1] == EXPECTED_PARALLELISM
assert params[2] == EXPECTED_ITERATIONS
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
assertEquals(EXPECTED_MEMORY, params[0])
assertEquals(EXPECTED_PARALLELISM, params[1])
assertEquals(EXPECTED_ITERATIONS, params[2])
}
@Test
@ -485,7 +484,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Argon2 formatted salt: ${isValid}")
// Assert
assert !isValid
assertFalse(isValid)
}
@Test
@ -505,6 +504,6 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("rawSalt: ${Hex.encodeHexString(rawSalt)}")
// Assert
assert rawSalt == EXPECTED_RAW_SALT
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
}
}

View File

@ -27,7 +27,12 @@ import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class Argon2SecureHasherTest {
private static final Logger logger = LoggerFactory.getLogger(Argon2SecureHasherTest.class)
@ -68,7 +73,7 @@ class Argon2SecureHasherTest {
}
// Assert
assert results.every { it == EXPECTED_HASH_HEX }
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -98,8 +103,8 @@ class Argon2SecureHasherTest {
}
// Assert
assert results.unique().size() == results.size()
assert results.every { it != EXPECTED_HASH_HEX }
assertTrue(results.unique().size() == results.size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -143,20 +148,20 @@ class Argon2SecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assert staticSaltHash == EXPECTED_HASH_BYTES
assert arbitrarySaltHash == EXPECTED_HASH_BYTES
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES
assert differentSaltHash != EXPECTED_HASH_BYTES
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX
assert differentSaltHashHex != EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
@ -198,7 +203,7 @@ class Argon2SecureHasherTest {
logger.info("Generated hash: ${hashHex}")
// Assert
assert hashHex == EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
@ -215,7 +220,7 @@ class Argon2SecureHasherTest {
logger.info("Generated hash: ${hashB64}")
// Assert
assert hashB64 == EXPECTED_HASH_B64
assertEquals(EXPECTED_HASH_B64, hashB64)
}
@Test
@ -243,8 +248,8 @@ class Argon2SecureHasherTest {
}
// Assert
assert hexResults.every { it == EXPECTED_HASH_HEX }
assert b64Results.every { it == EXPECTED_HASH_B64 }
hexResults.forEach(hexResult -> assertEquals(EXPECTED_HASH_HEX, hexResult))
b64Results.forEach(b64Result -> assertEquals(EXPECTED_HASH_B64, b64Result))
}
/**
@ -282,8 +287,8 @@ class Argon2SecureHasherTest {
// Assert
final long MIN_DURATION_NANOS = 500_000_000 // 500 ms
assert resultDurations.min() > MIN_DURATION_NANOS
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
@ -295,7 +300,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isHashLengthValid(hashLength)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -312,7 +317,7 @@ class Argon2SecureHasherTest {
// Assert
results.each { hashLength, isHashLengthValid ->
logger.info("For hashLength value ${hashLength}, hashLength is ${isHashLengthValid ? "valid" : "invalid"}")
assert !isHashLengthValid
assertFalse(isHashLengthValid)
}
}
@ -325,7 +330,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isMemorySizeValid(memory)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -342,7 +347,7 @@ class Argon2SecureHasherTest {
// Assert
results.each { memory, isMemorySizeValid ->
logger.info("For memory size ${memory}, memory is ${isMemorySizeValid ? "valid" : "invalid"}")
assert !isMemorySizeValid
assertFalse(isMemorySizeValid)
}
}
@ -355,7 +360,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isParallelismValid(parallelism)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -372,7 +377,7 @@ class Argon2SecureHasherTest {
// Assert
results.each { parallelism, isParallelismValid ->
logger.info("For parallelization factor ${parallelism}, parallelism is ${isParallelismValid ? "valid" : "invalid"}")
assert !isParallelismValid
assertFalse(isParallelismValid)
}
}
@ -385,7 +390,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isIterationsValid(iterations)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -402,7 +407,7 @@ class Argon2SecureHasherTest {
// Assert
results.each { iterations, isIterationsValid ->
logger.info("For iteration counts ${iterations}, iteration is ${isIterationsValid ? "valid" : "invalid"}")
assert !isIterationsValid
assertFalse(isIterationsValid)
}
}
@ -411,17 +416,11 @@ class Argon2SecureHasherTest {
// Arrange
def saltLengths = [0, 64]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new Argon2SecureHasher().isSaltLengthValid(saltLength)
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
logger.info("For salt length ${saltLength}, saltLength is ${isSaltLengthValid ? "valid" : "invalid"}")
assert isSaltLengthValid
}
// Act and Assert
Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher()
saltLengths.forEach(saltLength -> {
assertTrue(argon2SecureHasher.isSaltLengthValid(saltLength))
})
}
@Test
@ -429,17 +428,9 @@ class Argon2SecureHasherTest {
// Arrange
def saltLengths = [-16, 4]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new Argon2SecureHasher().isSaltLengthValid(saltLength)
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
logger.info("For salt length ${saltLength}, saltLength is ${isSaltLengthValid ? "valid" : "invalid"}")
assert !isSaltLengthValid
}
// Act and Assert
Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher()
saltLengths.forEach(saltLength -> assertFalse(argon2SecureHasher.isSaltLengthValid(saltLength)))
}
@Test
@ -460,7 +451,7 @@ class Argon2SecureHasherTest {
}
// Assert
assert results[16][0..15] != results[32][0..15]
assertFalse(Arrays.equals(Arrays.copyOf(results[16], 16), Arrays.copyOf(results[32], 16)))
// Demonstrates that internal hash truncation is not supported
// assert results.every { int k, byte[] v -> v[0..15] as byte[] == EXPECTED_HASH}
}

View File

@ -34,9 +34,10 @@ import java.nio.charset.StandardCharsets
import java.security.MessageDigest
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.Assert.assertTrue
import static org.junit.jupiter.api.Assumptions.assumeTrue
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assertions.assertThrows
class BcryptCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(BcryptCipherProviderGroovyTest.class)
@ -88,7 +89,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -118,7 +119,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -150,7 +151,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -171,8 +172,8 @@ class BcryptCipherProviderGroovyTest {
logger.info("Generated ${secureHasherCalculatedHash}")
// Assert
assert secureHasherCalculatedHash == EXPECTED_HASH
assert secureHasherCalculatedHash == EXPECTED_HASH
assertEquals(EXPECTED_HASH, secureHasherCalculatedHash)
assertEquals(EXPECTED_HASH, secureHasherCalculatedHash)
}
@Test
@ -217,8 +218,8 @@ class BcryptCipherProviderGroovyTest {
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.info("Expected cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assert rubyCipher.doFinal(rubyCipherBytes) == PLAINTEXT.bytes
assert rubyCipher.doFinal(cipherBytes) == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text and generated cipher text successfully")
// Sanity for hash generation
@ -226,7 +227,7 @@ class BcryptCipherProviderGroovyTest {
logger.sanity("Salt from external: ${FULL_SALT}")
String generatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, BcryptCipherProvider.extractRawSalt(FULL_SALT), PASSWORD.bytes))
logger.sanity("Generated hash: ${generatedHash}")
assert generatedHash == FULL_HASH
assertEquals(FULL_HASH, generatedHash)
// Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, FULL_SALT.bytes, IV, DEFAULT_KEY_LENGTH, false)
@ -235,7 +236,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
private static byte[] customB64Decode(String input) {
@ -294,7 +295,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
@Test
@ -313,13 +314,12 @@ class BcryptCipherProviderGroovyTest {
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "The salt must be of the format \\\$2a\\\$10\\\$gUVbkVzp79H8YaCOsCVZNu\\. To generate a salt, use BcryptCipherProvider#generateSalt"
assertTrue(iae.getMessage().contains("The salt must be of the format \$2a\$10\$gUVbkVzp79H8YaCOsCVZNu. To generate a salt, use BcryptCipherProvider#generateSalt"))
}
}
@ -346,13 +346,12 @@ class BcryptCipherProviderGroovyTest {
// Two different errors -- one explaining the no-salt method is not supported, and the other for an empty byte[] passed
// Act
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "The salt must be of the format .* To generate a salt, use BcryptCipherProvider#generateSalt"
assertTrue((iae.getMessage() =~ "The salt must be of the format .* To generate a salt, use BcryptCipherProvider#generateSalt").find())
}
@Test
@ -375,12 +374,11 @@ class BcryptCipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@ -414,7 +412,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -436,12 +434,11 @@ class BcryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
// Assert
assert msg =~ "${keyLength} is not a valid key length for AES"
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@ -457,8 +454,9 @@ class BcryptCipherProviderGroovyTest {
logger.info("Salt: ${salt}")
// Assert
assert salt =~ /^\$2[axy]\$\d{2}\$/
assert salt.contains("\$${workFactor}\$")
assertTrue((salt =~ /^\$2[axy]\$\d{2}\$/).find())
assertTrue(salt.contains("\$" + workFactor + "\$"))
}
/**
@ -518,8 +516,8 @@ class BcryptCipherProviderGroovyTest {
logger.info("Verified: ${verificationRecovered}")
// Assert
assert PLAINTEXT == recovered
assert PLAINTEXT == verificationRecovered
assertEquals(PLAINTEXT, recovered)
assertEquals(PLAINTEXT, verificationRecovered)
}
}
@ -549,7 +547,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT == recovered
assertEquals(PLAINTEXT, recovered)
}
}
@ -566,27 +564,22 @@ class BcryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption
def encryptMsg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true)
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}")
}
logger.expected("Encrypt error: ${encryptMsg}")
IllegalArgumentException encryptIae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true))
logger.warn("Encrypt error: " + encryptIae.getMessage())
byte[] cipherBytes = PLAINTEXT.reverse().getBytes(StandardCharsets.UTF_8)
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def decryptMsg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, [0x00] * 16 as byte[], DEFAULT_KEY_LENGTH, false)
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8")
logger.info("Recovered: ${recovered}")
}
logger.expected("Decrypt error: ${decryptMsg}")
IllegalArgumentException decryptIae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, [0x00] * 16 as byte[], DEFAULT_KEY_LENGTH, false))
logger.warn("Decrypt error: " + decryptIae.getMessage())
// Assert
assert encryptMsg =~ "The salt must be of the format"
assert decryptMsg =~ "The salt must be of the format"
assertTrue(encryptIae.getMessage().contains("The salt must be of the format"))
assertTrue(decryptIae.getMessage().contains("The salt must be of the format"))
}
@Disabled("This test can be run on a specific machine to evaluate if the default work factor is sufficient")

View File

@ -26,7 +26,12 @@ import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class BcryptSecureHasherTest {
private static final Logger logger = LoggerFactory.getLogger(BcryptSecureHasher)
@ -62,7 +67,7 @@ class BcryptSecureHasherTest {
}
// Assert
assert results.every { it == EXPECTED_HASH_HEX }
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -90,8 +95,8 @@ class BcryptSecureHasherTest {
}
// Assert
assert results.unique().size() == results.size()
assert results.every { it != EXPECTED_HASH_HEX }
assertEquals(results.size(), results.unique().size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -131,21 +136,20 @@ class BcryptSecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assert staticSaltHash == EXPECTED_HASH_BYTES
assert arbitrarySaltHash == EXPECTED_HASH_BYTES
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES
assert differentSaltHash != EXPECTED_HASH_BYTES
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX
assert differentSaltHashHex != EXPECTED_HASH_HEX
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
@ -182,7 +186,7 @@ class BcryptSecureHasherTest {
logger.info("Generated hash: ${hashHex}")
// Assert
assert hashHex == EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
@ -199,7 +203,7 @@ class BcryptSecureHasherTest {
logger.info("Generated hash: ${hashB64}")
// Assert
assert hashB64 == EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, hashB64)
}
@Test
@ -227,8 +231,8 @@ class BcryptSecureHasherTest {
}
// Assert
assert hexResults.every { it == EXPECTED_HASH_HEX }
assert B64Results.every { it == EXPECTED_HASH_BASE64 }
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
}
/**
@ -263,8 +267,8 @@ class BcryptSecureHasherTest {
// Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assert resultDurations.min() > MIN_DURATION_NANOS
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
@ -272,11 +276,8 @@ class BcryptSecureHasherTest {
// Arrange
final int cost = 14
// Act
boolean valid = BcryptSecureHasher.isCostValid(cost)
// Assert
assert valid
// Act and Assert
assertTrue(BcryptSecureHasher.isCostValid(cost))
}
@Test
@ -284,17 +285,8 @@ class BcryptSecureHasherTest {
// Arrange
def costFactors = [-8, 0, 40]
// Act
def results = costFactors.collect { costFactor ->
def isValid = BcryptSecureHasher.isCostValid(costFactor)
[costFactor, isValid]
}
// Assert
results.each { costFactor, isCostValid ->
logger.info("For cost factor ${costFactor}, cost is ${isCostValid ? "valid" : "invalid"}")
assert !isCostValid
}
// Act and Assert
costFactors.forEach(costFactor -> assertFalse(BcryptSecureHasher.isCostValid(costFactor)))
}
@Test
@ -302,16 +294,9 @@ class BcryptSecureHasherTest {
// Arrange
def saltLengths = [0, 16]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new BcryptSecureHasher().isSaltLengthValid(saltLength)
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
assert { it == isSaltLengthValid }
}
// Act and Assert
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher()
saltLengths.forEach(saltLength -> assertTrue(bcryptSecureHasher.isSaltLengthValid(saltLength)))
}
@Test
@ -319,17 +304,9 @@ class BcryptSecureHasherTest {
// Arrange
def saltLengths = [-8, 1]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new BcryptSecureHasher().isSaltLengthValid(saltLength)
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
logger.info("For Salt Length value ${saltLength}, saltLength is ${isSaltLengthValid ? "valid" : "invalid"}")
assert !isSaltLengthValid
}
// Act and Assert
BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher()
saltLengths.forEach(saltLength -> assertFalse(bcryptSecureHasher.isSaltLengthValid(saltLength)))
}
@Test
@ -350,8 +327,8 @@ class BcryptSecureHasherTest {
logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}")
// Assert
assert convertedBase64 == EXPECTED_MIME_B64
assert convertedRadix64 == INPUT_RADIX_64
assertEquals(EXPECTED_MIME_B64, convertedBase64)
assertEquals(INPUT_RADIX_64, convertedRadix64)
}
@Test
@ -372,8 +349,8 @@ class BcryptSecureHasherTest {
logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}")
// Assert
assert convertedBase64 == EXPECTED_MIME_B64
assert convertedRadix64 == INPUT_RADIX_64
assertEquals(EXPECTED_MIME_B64, convertedBase64)
assertEquals(INPUT_RADIX_64, convertedRadix64)
}
}

View File

@ -1,64 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.nifi.security.util.KeyDerivationFunction
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.security.Security
class CipherProviderFactoryGroovyTest extends GroovyTestCase {
private static final Logger logger = LoggerFactory.getLogger(CipherProviderFactoryGroovyTest.class)
private static final Map<KeyDerivationFunction, Class> EXPECTED_CIPHER_PROVIDERS = [
(KeyDerivationFunction.BCRYPT) : BcryptCipherProvider.class,
(KeyDerivationFunction.NIFI_LEGACY) : NiFiLegacyCipherProvider.class,
(KeyDerivationFunction.NONE) : AESKeyedCipherProvider.class,
(KeyDerivationFunction.OPENSSL_EVP_BYTES_TO_KEY): OpenSSLPKCS5CipherProvider.class,
(KeyDerivationFunction.PBKDF2) : PBKDF2CipherProvider.class,
(KeyDerivationFunction.SCRYPT) : ScryptCipherProvider.class,
(KeyDerivationFunction.ARGON2) : Argon2CipherProvider.class
]
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testGetCipherProviderShouldResolveRegisteredKDFs() {
// Arrange
// Act
KeyDerivationFunction.values().each { KeyDerivationFunction kdf ->
logger.info("Expected: ${kdf.kdfName} -> ${EXPECTED_CIPHER_PROVIDERS.get(kdf).simpleName}")
CipherProvider cp = CipherProviderFactory.getCipherProvider(kdf)
logger.info("Resolved: ${kdf.kdfName} -> ${cp.class.simpleName}")
// Assert
assert cp.class == (EXPECTED_CIPHER_PROVIDERS.get(kdf))
}
}
}

View File

@ -27,7 +27,12 @@ import org.slf4j.LoggerFactory
import java.security.Security
class CipherUtilityGroovyTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertTrue
class CipherUtilityGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(CipherUtilityGroovyTest.class)
// TripleDES must precede DES for automatic grouping precedence
@ -93,7 +98,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Extracted ${cipher} from ${algorithm}")
// Assert
assert EXPECTED_ALGORITHMS.get(cipher).contains(algorithm)
assertTrue(EXPECTED_ALGORITHMS.get(cipher).contains(algorithm))
}
}
@ -108,7 +113,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Extracted ${keyLength} from ${algorithm}")
// Assert
assert EXPECTED_ALGORITHMS.get(keyLength).contains(algorithm)
assertTrue(EXPECTED_ALGORITHMS.get(keyLength).contains(algorithm))
}
}
@ -122,7 +127,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${keyLength} for ${algorithm}")
// Assert
assert CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))
assertTrue(CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm)))
}
}
}
@ -143,9 +148,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")
// Assert
invalidKeyLengths.each { int invalidKeyLength ->
assert !CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))
}
invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))))
}
}
}
@ -160,7 +163,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${keyLength} for ${algorithm}")
// Assert
assert CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm)
assertTrue(CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm))
}
}
}
@ -181,9 +184,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")
// Assert
invalidKeyLengths.each { int invalidKeyLength ->
assert !CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)
}
invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)))
}
}
@ -191,7 +192,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
String algorithm = "PBEWITHSHA256AND256BITAES-CBC-BC"
int invalidKeyLength = 192
logger.info("Checking ${invalidKeyLength} for ${algorithm}")
assert !CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)
assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm))
}
@Test
@ -223,7 +224,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")
// Assert
assert validKeySizes == EXPECTED_KEY_SIZES
assertEquals(EXPECTED_KEY_SIZES, validKeySizes)
}
// Act
@ -235,7 +236,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")
// Assert
assert validKeySizes == EXPECTED_KEY_SIZES
assertEquals(EXPECTED_KEY_SIZES, validKeySizes)
}
}
@ -269,10 +270,10 @@ the License. You may obtain a copy of the License at
logger.info("Looking for ${Hex.encodeHexString(kafka)}; found at ${kafkaIndex}")
// Assert
assert apacheIndex == 16
assert softwareIndex == 23
assert asfIndex == 44
assert kafkaIndex == -1
assertEquals(16, apacheIndex)
assertEquals(23, softwareIndex)
assertEquals(44, asfIndex)
assertEquals(-1, kafkaIndex)
}
@Test
@ -285,7 +286,7 @@ the License. You may obtain a copy of the License at
String SCRYPT_SALT = ScryptCipherProvider.formatSaltForScrypt(PLAIN_SALT, 10, 1, 1)
// Act
def results = KeyDerivationFunction.values().findAll { !it.isStrongKDF() }.collectEntries { KeyDerivationFunction weakKdf ->
Map<Object, byte[]> results = KeyDerivationFunction.values().findAll { !it.isStrongKDF() }.collectEntries { KeyDerivationFunction weakKdf ->
[weakKdf, CipherUtility.extractRawSalt(PLAIN_SALT, weakKdf)]
}
@ -295,6 +296,6 @@ the License. You may obtain a copy of the License at
results.put(KeyDerivationFunction.PBKDF2, CipherUtility.extractRawSalt(PLAIN_SALT, KeyDerivationFunction.PBKDF2))
// Assert
assert results.every { k, v -> v == PLAIN_SALT }
results.values().forEach(v -> assertArrayEquals(PLAIN_SALT, v))
}
}

View File

@ -25,7 +25,10 @@ import org.slf4j.LoggerFactory
import java.security.Security
class HashAlgorithmTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
class HashAlgorithmTest {
private static final Logger logger = LoggerFactory.getLogger(HashAlgorithmTest.class)
@ -48,7 +51,7 @@ class HashAlgorithmTest extends GroovyTestCase {
logger.info("Broken algorithms: ${brokenAlgorithms}")
// Assert
assert brokenAlgorithms == [HashAlgorithm.MD2, HashAlgorithm.MD5, HashAlgorithm.SHA1]
assertEquals([HashAlgorithm.MD2, HashAlgorithm.MD5, HashAlgorithm.SHA1], brokenAlgorithms)
}
@Test
@ -62,11 +65,11 @@ class HashAlgorithmTest extends GroovyTestCase {
}
// Assert
assert descriptions.every {
it =~ /.* \(\d+ byte output\).*/
}
descriptions.forEach(description -> assertTrue((description =~ /.* \(\d+ byte output\).*/).find()) )
assert descriptions.findAll { it =~ "MD2|MD5|SHA-1" }.every { it =~ /\[WARNING/ }
descriptions.stream()
.filter(description -> (description =~ "MD2|MD5|SHA-1").find() )
.forEach(description -> assertTrue(description.contains("WARNING")))
}
@Test
@ -78,7 +81,7 @@ class HashAlgorithmTest extends GroovyTestCase {
logger.info("Blake2 algorithms: ${blake2Algorithms}")
// Assert
assert blake2Algorithms == [HashAlgorithm.BLAKE2_160, HashAlgorithm.BLAKE2_256, HashAlgorithm.BLAKE2_384, HashAlgorithm.BLAKE2_512]
assertEquals([HashAlgorithm.BLAKE2_160, HashAlgorithm.BLAKE2_256, HashAlgorithm.BLAKE2_384, HashAlgorithm.BLAKE2_512], blake2Algorithms)
}
@Test
@ -95,8 +98,7 @@ class HashAlgorithmTest extends GroovyTestCase {
HashAlgorithm found = HashAlgorithm.fromName(name)
// Assert
assert found instanceof HashAlgorithm
assert found.name == name.toUpperCase()
assertEquals(name.toUpperCase(), found.name)
}
}
}

View File

@ -29,7 +29,14 @@ import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import java.security.Security
class HashServiceTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class HashServiceTest {
private static final Logger logger = LoggerFactory.getLogger(HashServiceTest.class)
@BeforeAll
@ -70,9 +77,9 @@ class HashServiceTest extends GroovyTestCase {
// Assert
if (result instanceof byte[]) {
assert result == EXPECTED_HASH_BYTES
assertArrayEquals(EXPECTED_HASH_BYTES, result)
} else {
assert result == EXPECTED_HASH
assertEquals(EXPECTED_HASH, result)
}
}
}
@ -90,11 +97,12 @@ class HashServiceTest extends GroovyTestCase {
logger.info("UTF-16: ${utf16Hash}")
// Assert
assert utf8Hash != utf16Hash
assertNotEquals(utf8Hash, utf16Hash)
}
/**
* This test ensures that the service properly handles UTF-16 encoded data to return it without the Big Endian Byte Order Mark (BOM). Java treats UTF-16 encoded data without a BOM as Big Endian by default on decoding, but when <em>encoding</em>, it inserts a BE BOM in the data.
* This test ensures that the service properly handles UTF-16 encoded data to return it without
* the Big Endian Byte Order Mark (BOM). Java treats UTF-16 encoded data without a BOM as Big Endian by default on decoding, but when <em>encoding</em>, it inserts a BE BOM in the data.
*
* Examples:
*
@ -138,7 +146,7 @@ class HashServiceTest extends GroovyTestCase {
logger.info("${algorithm.name}(${KNOWN_VALUE}, ${charset.name().padLeft(9)}) = ${hash}")
// Assert
assert hash == EXPECTED_SHA_256_HASHES[translateStringToMapKey(charset.name())]
assertEquals(EXPECTED_SHA_256_HASHES[translateStringToMapKey(charset.name())], hash)
}
}
@ -162,16 +170,15 @@ class HashServiceTest extends GroovyTestCase {
logger.info("Implicit UTF-8 bytes: ${implicitUTF8HashBytesDefault}")
// Assert
assert explicitUTF8Hash == implicitUTF8Hash
assert explicitUTF8HashBytes == implicitUTF8HashBytes
assert explicitUTF8HashBytes == implicitUTF8HashBytesDefault
assertEquals(explicitUTF8Hash, implicitUTF8Hash)
assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytes)
assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytesDefault)
}
@Test
void testShouldRejectNullAlgorithm() {
// Arrange
final String KNOWN_VALUE = "apachenifi"
Closure threeArgString = { -> HashService.hashValue(null, KNOWN_VALUE, StandardCharsets.UTF_8) }
Closure twoArgString = { -> HashService.hashValue(null, KNOWN_VALUE) }
Closure threeArgStringRaw = { -> HashService.hashValueRaw(null, KNOWN_VALUE, StandardCharsets.UTF_8) }
@ -186,15 +193,10 @@ class HashServiceTest extends GroovyTestCase {
]
// Act
scenarios.each { String name, Closure closure ->
def msg = shouldFail(IllegalArgumentException) {
closure.call()
}
logger.expected("${name.padLeft(20)}: ${msg}")
// Assert
assert msg =~ "The hash algorithm cannot be null"
}
scenarios.entrySet().forEach(entry -> {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> entry.getValue().call())
assertTrue(iae.message.contains("The hash algorithm cannot be null"))
})
}
@Test
@ -216,15 +218,10 @@ class HashServiceTest extends GroovyTestCase {
]
// Act
scenarios.each { String name, Closure closure ->
def msg = shouldFail(IllegalArgumentException) {
closure.call()
}
logger.expected("${name.padLeft(20)}: ${msg}")
// Assert
assert msg =~ "The value cannot be null"
}
scenarios.entrySet().forEach(entry -> {
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> entry.getValue().call())
assertTrue(iae.message.contains("The value cannot be null"))
})
}
@Test
@ -266,7 +263,7 @@ class HashServiceTest extends GroovyTestCase {
// Assert
generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName)
assert EXPECTED_HASHES[key] == hash
assertEquals(EXPECTED_HASHES[key], hash)
}
}
@ -309,7 +306,7 @@ class HashServiceTest extends GroovyTestCase {
// Assert
generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName)
assert EXPECTED_HASHES[key] == hash
assertEquals(EXPECTED_HASHES[key], hash)
}
}
@ -323,14 +320,14 @@ class HashServiceTest extends GroovyTestCase {
def allowableValues = HashService.buildHashAlgorithmAllowableValues()
// Assert
assert allowableValues instanceof AllowableValue[]
assertInstanceOf(AllowableValue[].class, allowableValues)
def valuesList = allowableValues as List<AllowableValue>
assert valuesList.size() == EXPECTED_ALGORITHMS.size()
assertEquals(EXPECTED_ALGORITHMS.size(), valuesList.size())
EXPECTED_ALGORITHMS.each { HashAlgorithm expectedAlgorithm ->
def matchingValue = valuesList.find { it.value == expectedAlgorithm.name }
assert matchingValue.displayName == expectedAlgorithm.name
assert matchingValue.description == expectedAlgorithm.buildAllowableValueDescription()
assertEquals(expectedAlgorithm.name, matchingValue.displayName)
assertEquals(expectedAlgorithm.buildAllowableValueDescription(), matchingValue.description)
}
}
@ -347,20 +344,21 @@ class HashServiceTest extends GroovyTestCase {
]
logger.info("The consistent list of character sets available [${EXPECTED_CHARACTER_SETS.size()}]: \n${EXPECTED_CHARACTER_SETS.collect { "\t${it.name()}" }.join("\n")}")
def expectedDescriptions = ["UTF-16": "This character set normally decodes using an optional BOM at the beginning of the data but encodes by inserting a BE BOM. For hashing, it will be replaced with UTF-16BE. "]
def expectedDescriptions =
["UTF-16": "This character set normally decodes using an optional BOM at the beginning of the data but encodes by inserting a BE BOM. For hashing, it will be replaced with UTF-16BE. "]
// Act
def allowableValues = HashService.buildCharacterSetAllowableValues()
// Assert
assert allowableValues instanceof AllowableValue[]
assertInstanceOf(AllowableValue[].class, allowableValues)
def valuesList = allowableValues as List<AllowableValue>
assert valuesList.size() == EXPECTED_CHARACTER_SETS.size()
assertEquals(EXPECTED_CHARACTER_SETS.size(), valuesList.size())
EXPECTED_CHARACTER_SETS.each { Charset charset ->
def matchingValue = valuesList.find { it.value == charset.name() }
assert matchingValue.displayName == charset.name()
assert matchingValue.description == (expectedDescriptions[charset.name()] ?: charset.displayName())
assertEquals(charset.name(), matchingValue.displayName)
assertEquals((expectedDescriptions[charset.name()] ?: charset.displayName()), matchingValue.description)
}
}
@ -410,7 +408,7 @@ class HashServiceTest extends GroovyTestCase {
// Assert
generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName)
assert EXPECTED_HASHES[key] == hash
assertEquals(EXPECTED_HASHES[key], hash)
}
}

View File

@ -21,7 +21,6 @@ import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@ -32,6 +31,8 @@ import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.PBEParameterSpec
import java.security.Security
import static org.junit.jupiter.api.Assertions.assertEquals
class NiFiLegacyCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(NiFiLegacyCipherProviderGroovyTest.class)
@ -100,7 +101,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -130,7 +131,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -165,7 +166,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -199,7 +200,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -230,7 +231,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
}

View File

@ -31,8 +31,11 @@ import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.PBEParameterSpec
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.Assert.fail
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assertions.fail
class OpenSSLPKCS5CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(OpenSSLPKCS5CipherProviderGroovyTest.class)
@ -99,7 +102,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -128,7 +131,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -162,7 +165,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -196,7 +199,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -227,7 +230,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8")
// Assert
assert plaintext.equals(recovered)
assertEquals(plaintext, recovered)
}
}
@ -242,12 +245,11 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
// Act
logger.info("Using algorithm: null")
def msg = shouldFail(IllegalArgumentException) {
Cipher providedCipher = cipherProvider.getCipher(null, PASSWORD, SALT, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(null, PASSWORD, SALT, false))
// Assert
assert msg =~ "The encryption method must be specified"
assertTrue(iae.getMessage().contains("The encryption method must be specified"))
}
@Test
@ -261,12 +263,11 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
// Act
logger.info("Using algorithm: ${encryptionMethod}")
def msg = shouldFail(IllegalArgumentException) {
Cipher providedCipher = cipherProvider.getCipher(encryptionMethod, "", SALT, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, "", SALT, false))
// Assert
assert msg =~ "Encryption with an empty password is not supported"
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
}
@Test
@ -280,13 +281,11 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
// Act
logger.info("Using algorithm: ${encryptionMethod}")
def msg = shouldFail(IllegalArgumentException) {
Cipher providedCipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, false))
// Assert
assert msg =~ "Salt must be 8 bytes US-ASCII encoded"
assertTrue(iae.getMessage().contains("Salt must be 8 bytes US-ASCII encoded"))
}
@Test
@ -299,7 +298,9 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
logger.info("Checking salt ${Hex.encodeHexString(salt)}")
// Assert
assert salt.length == cipherProvider.getDefaultSaltLength()
assert salt != [(0x00 as byte) * cipherProvider.defaultSaltLength]
assertEquals(cipherProvider.getDefaultSaltLength(), salt.length)
byte [] notExpected = new byte [cipherProvider.defaultSaltLength]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, salt))
}
}

View File

@ -28,8 +28,11 @@ import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.Assert.assertTrue
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class PBKDF2CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(PBKDF2CipherProviderGroovyTest.class)
@ -85,7 +88,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -108,12 +111,11 @@ class PBKDF2CipherProviderGroovyTest {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true)
// Decrypt should fail
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false))
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@ -143,7 +145,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -174,7 +176,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -192,12 +194,11 @@ class PBKDF2CipherProviderGroovyTest {
// Act
logger.info("Using PRF ${prf}")
def msg = shouldFail(IllegalArgumentException) {
cipherProvider = new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT))
// Assert
assert msg =~ "Cannot resolve empty PRF"
assertTrue(iae.getMessage().contains("Cannot resolve empty PRF"))
}
@Test
@ -234,7 +235,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
@Test
@ -270,7 +271,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -300,7 +301,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
@Test
@ -324,15 +325,15 @@ class PBKDF2CipherProviderGroovyTest {
byte[] sha512CipherBytes = sha512Cipher.doFinal(PLAINTEXT.bytes)
// Assert
assert sha512CipherBytes != sha256CipherBytes
assertFalse(Arrays.equals(sha512CipherBytes, sha256CipherBytes))
Cipher sha256DecryptCipher = sha256CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] sha256RecoveredBytes = sha256DecryptCipher.doFinal(sha256CipherBytes)
assert sha256RecoveredBytes == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, sha256RecoveredBytes)
Cipher sha512DecryptCipher = sha512CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] sha512RecoveredBytes = sha512DecryptCipher.doFinal(sha512CipherBytes)
assert sha512RecoveredBytes == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, sha512RecoveredBytes)
}
@Test
@ -355,12 +356,11 @@ class PBKDF2CipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains( "Cannot decrypt without a valid IV"))
}
}
@ -380,12 +380,11 @@ class PBKDF2CipherProviderGroovyTest {
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt?.bytes, DEFAULT_KEY_LENGTH, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt?.bytes, DEFAULT_KEY_LENGTH, true))
// Assert
assert msg =~ "The salt must be at least 16 bytes\\. To generate a salt, use PBKDF2CipherProvider#generateSalt"
assertTrue(iae.getMessage().contains("The salt must be at least 16 bytes. To generate a salt, use PBKDF2CipherProvider#generateSalt"))
}
}
@ -419,7 +418,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -441,12 +440,11 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
// Assert
assert msg =~ "${keyLength} is not a valid key length for AES"
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@ -537,7 +535,9 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Checking salt ${Hex.encodeHexString(salt)}")
// Assert
assert salt.length == 16
assert salt != [(0x00 as byte) * 16]
assertEquals(16,salt.length )
byte [] notExpected = new byte[16]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, salt))
}
}

View File

@ -21,8 +21,15 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import java.nio.charset.StandardCharsets
import java.util.stream.Collectors
import java.util.stream.Stream
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class PBKDF2SecureHasherTest {
@ -31,25 +38,21 @@ class PBKDF2SecureHasherTest {
// Arrange
int cost = 10_000
int dkLength = 32
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511"
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(cost, dkLength)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes)
String hashHex = new String(Hex.encode(hash))
results << hashHex
}
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(cost, dkLength)
List<String> results = Stream.iterate(0, n -> n + 1)
.limit(10)
.map(iteration -> {
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes)
return new String(Hex.encode(hash))
})
.collect(Collectors.toList())
// Assert
assert results.every { it == EXPECTED_HASH_HEX }
results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -59,26 +62,22 @@ class PBKDF2SecureHasherTest {
int cost = 10_000
int saltLength = 16
int dkLength = 32
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511"
//Act
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength)
def results = []
// Act
testIterations.times { int i ->
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes)
String hashHex = Hex.encode(hash)
results << hashHex
}
List<String> results = Stream.iterate(0, n -> n + 1)
.limit(10)
.map(iteration -> {
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes)
return new String(Hex.encode(hash))
})
.collect(Collectors.toList())
// Assert
assert results.unique().size() == results.size()
assert results.every { it != EXPECTED_HASH_HEX }
assertEquals(results.unique().size(), results.size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -119,20 +118,20 @@ class PBKDF2SecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assert staticSaltHash == EXPECTED_HASH_BYTES
assert arbitrarySaltHash == EXPECTED_HASH_BYTES
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES
assert differentSaltHash != EXPECTED_HASH_BYTES
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX
assert differentSaltHashHex != EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
@ -169,7 +168,7 @@ class PBKDF2SecureHasherTest {
String hashHex = pbkdf2SecureHasher.hashHex(input)
// Assert
assert hashHex == EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
@ -185,7 +184,7 @@ class PBKDF2SecureHasherTest {
String hashB64 = pbkdf2SecureHasher.hashBase64(input)
// Assert
assert hashB64 == EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, hashB64)
}
@Test
@ -196,23 +195,18 @@ class PBKDF2SecureHasherTest {
final String EXPECTED_HASH_HEX = "7f2d8d8c7aaa45471f6c05a8edfe0a3f75fe01478cc965c5dce664e2ac6f5d0a"
final String EXPECTED_HASH_BASE64 = "fy2NjHqqRUcfbAWo7f4KP3X+AUeMyWXF3OZk4qxvXQo"
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
def hexResults = []
def B64Results = []
// Act
inputs.each { String input ->
String hashHex = pbkdf2SecureHasher.hashHex(input)
hexResults << hashHex
String hashB64 = pbkdf2SecureHasher.hashBase64(input)
B64Results << hashB64
}
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
List<String> hexResults = inputs.stream()
.map(input -> pbkdf2SecureHasher.hashHex(input))
.collect(Collectors.toList())
List<String> B64Results = inputs.stream()
.map(input -> pbkdf2SecureHasher.hashBase64(input))
.collect(Collectors.toList())
// Assert
assert hexResults.every { it == EXPECTED_HASH_HEX }
assert B64Results.every { it == EXPECTED_HASH_BASE64 }
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
}
/**
@ -246,8 +240,8 @@ class PBKDF2SecureHasherTest {
// Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assert resultDurations.min() > MIN_DURATION_NANOS
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
@ -262,99 +256,67 @@ class PBKDF2SecureHasherTest {
}
// Assert
assert results.every()
assertTrue(results.every())
}
@Test
void testShouldFailIterationCountBoundary() throws Exception {
// Arrange
def invalidIterationCounts = [-1, 0, Integer.MAX_VALUE + 1]
List<Integer> invalidIterationCounts = [-1, 0, Integer.MAX_VALUE + 1]
// Act
def results = invalidIterationCounts.collect { i ->
boolean valid = PBKDF2SecureHasher.isIterationCountValid(i)
valid
}
// Assert
results.each { valid ->
assert !valid
}
// Act and Assert
invalidIterationCounts.forEach(i -> assertFalse(PBKDF2SecureHasher.isIterationCountValid(i)))
}
@Test
void testShouldVerifyDKLengthBoundary() throws Exception {
// Arrange
def validHLengths = [32, 64]
List<Integer> validHLengths = [32, 64]
// 1 and MAX_VALUE are the length boundaries, inclusive
def validDKLengths = [1, 1000, 1_000_000, Integer.MAX_VALUE]
List<Integer> validDKLengths = [1, 1000, 1_000_000, Integer.MAX_VALUE]
// Act
def results = validHLengths.collectEntries { int hLen ->
def dkResults = validDKLengths.collect { int dkLength ->
boolean valid = PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength)
valid
}
[hLen, dkResults]
}
// Assert
results.each { int hLen, def dkResults ->
assert dkResults.every()
}
// Act and Assert
validHLengths.forEach(hLen -> {
validDKLengths.forEach(dkLength -> {
assertTrue(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength))
})
})
}
@Test
void testShouldFailDKLengthBoundary() throws Exception {
// Arrange
def validHLengths = [32, 64]
List<Integer> validHLengths = [32, 64]
// MAX_VALUE + 1 will become MIN_VALUE because of signed integer math
def invalidDKLengths = [-1, 0, Integer.MAX_VALUE + 1, new Integer(Integer.MAX_VALUE * 2 - 1)]
List<Integer> invalidDKLengths = [-1, 0, Integer.MAX_VALUE + 1, new Integer(Integer.MAX_VALUE * 2 - 1)]
// Act
def results = validHLengths.collectEntries { int hLen ->
def dkResults = invalidDKLengths.collect { int dkLength ->
boolean valid = PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength)
valid
}
[hLen, dkResults]
}
// Assert
results.each { int hLen, def dkResults ->
assert dkResults.every { boolean valid -> !valid }
}
// Act and Assert
validHLengths.forEach(hLen -> {
invalidDKLengths.forEach(dkLength -> {
assertFalse(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength))
})
})
}
@Test
void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [0, 16, 64]
List<Integer> saltLengths = [0, 16, 64]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new PBKDF2SecureHasher().isSaltLengthValid(saltLength)
isValid
}
// Assert
assert results.every()
// Act and Assert
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
saltLengths.forEach(saltLength -> assertTrue(pbkdf2SecureHasher.isSaltLengthValid(saltLength)))
}
@Test
void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange
def saltLengths = [-8, 1, Integer.MAX_VALUE + 1]
List<Integer> saltLengths = [-8, 1, Integer.MAX_VALUE + 1]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new PBKDF2SecureHasher().isSaltLengthValid(saltLength)
isValid
}
// Assert
results.each { assert !it }
// Act and Assert
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
saltLengths.forEach(saltLength -> assertFalse(pbkdf2SecureHasher.isSaltLengthValid(saltLength)))
}
}

View File

@ -35,8 +35,12 @@ import javax.crypto.spec.SecretKeySpec
import java.security.SecureRandom
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.Assert.assertTrue
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class ScryptCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(ScryptCipherProviderGroovyTest.class)
@ -94,7 +98,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -122,7 +126,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -152,7 +156,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -189,9 +193,9 @@ class ScryptCipherProviderGroovyTest {
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assert rubyCipher.doFinal(rubyCipherBytes) == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
logger.sanity("Decrypted generated cipher text successfully")
assert rubyCipher.doFinal(cipherBytes) == PLAINTEXT.bytes
assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text successfully")
// n$r$p$hex_salt_SL$hex_hash_HL
@ -214,7 +218,7 @@ class ScryptCipherProviderGroovyTest {
// Convert hash from hex to Base64
String base64Hash = CipherUtility.encodeBase64NoPadding(Hex.decodeHex(hashHex as char[]))
logger.info("Converted hash from hex ${hashHex} to Base64 ${base64Hash}")
assert Hex.encodeHexString(Base64.decodeBase64(base64Hash)) == hashHex
assertEquals(hashHex, Hex.encodeHexString(Base64.decodeBase64(base64Hash)))
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}")
@ -226,7 +230,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
@Test
@ -272,7 +276,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
@Test
@ -290,13 +294,12 @@ class ScryptCipherProviderGroovyTest {
INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assert msg =~ LENGTH_MESSAGE
assertTrue(iae.getMessage().contains(LENGTH_MESSAGE))
}
}
@ -317,7 +320,7 @@ class ScryptCipherProviderGroovyTest {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
// Assert
assert cipher
assertNotNull(cipher)
}
}
@ -330,13 +333,12 @@ class ScryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "The salt cannot be empty\\. To generate a salt, use ScryptCipherProvider#generateSalt"
assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use ScryptCipherProvider#generateSalt"))
}
@Test
@ -357,13 +359,12 @@ class ScryptCipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) {
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "Cannot decrypt without a valid IV"
assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
}
}
@ -394,7 +395,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}")
// Assert
assert PLAINTEXT.equals(recovered)
assertEquals(PLAINTEXT, recovered)
}
}
@ -415,13 +416,12 @@ class ScryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "${keyLength} is not a valid key length for AES"
assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
}
}
@ -434,12 +434,11 @@ class ScryptCipherProviderGroovyTest {
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act
def msg = shouldFail(IllegalArgumentException) {
cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
()-> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true))
// Assert
assert msg =~ "Encryption with an empty password is not supported"
assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
}
@Test
@ -455,9 +454,9 @@ class ScryptCipherProviderGroovyTest {
logger.info("Salt: ${salt}")
// Assert
assert salt =~ "^(?i)\\\$s0\\\$[a-f0-9]{5,16}\\\$"
assertTrue((salt =~ "^(?i)\\\$s0\\\$[a-f0-9]{5,16}\\\$").find())
String params = Scrypt.encodeParams(n, r, p)
assert salt.contains("\$${params}\$")
assertTrue(salt.contains("\$${params}\$"))
}
@Test
@ -480,10 +479,10 @@ class ScryptCipherProviderGroovyTest {
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params)
// Assert
assert rawSalt == EXPECTED_RAW_SALT
assert params[0] == EXPECTED_N
assert params[1] == EXPECTED_R
assert params[2] == EXPECTED_P
assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
assertEquals(EXPECTED_N, params[0])
assertEquals(EXPECTED_R, params[1])
assertEquals(EXPECTED_P, params[2])
}
@Test
@ -496,7 +495,7 @@ class ScryptCipherProviderGroovyTest {
boolean valid = ScryptCipherProvider.isPValid(r, p)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -504,19 +503,12 @@ class ScryptCipherProviderGroovyTest {
// Arrange
// The p upper bound is calculated with the formula below, when r = 8:
// pBoundary = ((Math.pow(2,32))-1) * (32.0/(r * 128)), where pBoundary = 134217727.96875;
Map costParameters = [8:134217729, 128:8388608, 4096: 0]
Map<Integer, Integer> costParameters = [8:134217729, 128:8388608, 4096: 0]
// Act
def results = costParameters.collectEntries { r, p ->
def isValid = ScryptCipherProvider.isPValid(r, p)
[r, isValid]
}
// Assert
results.each { r, isPValid ->
logger.info("For r ${r}, p is ${isPValid}")
assert !isPValid
}
// Act and Assert
costParameters.entrySet().forEach(entry -> {
assertFalse(ScryptCipherProvider.isPValid(entry.getKey(), entry.getValue()))
})
}
@Test
@ -528,7 +520,7 @@ class ScryptCipherProviderGroovyTest {
boolean valid = ScryptCipherProvider.isRValid(r)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -540,7 +532,7 @@ class ScryptCipherProviderGroovyTest {
boolean valid = ScryptCipherProvider.isRValid(r)
// Assert
assert !valid
assertFalse(valid)
}
@Test
@ -554,7 +546,7 @@ class ScryptCipherProviderGroovyTest {
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p)
// Assert
assert testCipherProvider
assertNotNull(testCipherProvider)
}
@Test
@ -565,13 +557,12 @@ class ScryptCipherProviderGroovyTest {
final int p = 0
// Act
def msg = shouldFail(IllegalArgumentException) {
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new ScryptCipherProvider(n, r, p))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "Invalid p value exceeds p boundary"
assertTrue(iae.getMessage().contains("Invalid p value exceeds p boundary"))
}
@Test
@ -582,13 +573,12 @@ class ScryptCipherProviderGroovyTest {
final int p = 0
// Act
def msg = shouldFail(IllegalArgumentException) {
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> new ScryptCipherProvider(n, r, p))
logger.warn(iae.getMessage())
// Assert
assert msg =~ "Invalid r value; must be greater than 0"
assertTrue(iae.getMessage().contains("Invalid r value; must be greater than 0"))
}
@Test
@ -601,7 +591,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Is Scrypt salt: ${isScryptSalt}")
// Assert
assert isScryptSalt
assertTrue(isScryptSalt)
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true",
@ -624,7 +614,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Determined minimum safe parameters to be N=${minimumN}, r=${minimumR}, p=${minimumP}")
// Assert
assertTrue("The default parameters for ScryptCipherProvider are too weak. Please update the default values to a stronger level.", n >= minimumN)
assertTrue(n >= minimumN, "The default parameters for ScryptCipherProvider are too weak. Please update the default values to a stronger level.")
}
/**
@ -644,7 +634,7 @@ class ScryptCipherProviderGroovyTest {
int n = 2**4
int dkLen = 128
assert Scrypt.calculateExpectedMemory(n, r, p) <= maxHeapSize
assertTrue(Scrypt.calculateExpectedMemory(n, r, p) <= maxHeapSize)
byte[] salt = new byte[Scrypt.defaultSaltLength]
new SecureRandom().nextBytes(salt)

View File

@ -22,7 +22,12 @@ import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class ScryptSecureHasherTest {
@ -51,7 +56,7 @@ class ScryptSecureHasherTest {
}
// Assert
assert results.every { it == EXPECTED_HASH_HEX }
results.forEach( result -> assertEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -79,8 +84,8 @@ class ScryptSecureHasherTest {
}
// Assert
assert results.unique().size() == results.size()
assert results.every { it != EXPECTED_HASH_HEX }
assertTrue(results.unique().size() == results.size())
results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
}
@Test
@ -122,20 +127,20 @@ class ScryptSecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert
assert staticSaltHash == EXPECTED_HASH_BYTES
assert arbitrarySaltHash == EXPECTED_HASH_BYTES
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES
assert differentSaltHash != EXPECTED_HASH_BYTES
assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX
assert differentSaltHashHex != EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
}
@Test
@ -172,7 +177,7 @@ class ScryptSecureHasherTest {
String hashHex = scryptSH.hashHex(input)
// Assert
assert hashHex == EXPECTED_HASH_HEX
assertEquals(EXPECTED_HASH_HEX, hashHex)
}
@Test
@ -188,7 +193,7 @@ class ScryptSecureHasherTest {
String hashB64 = scryptSH.hashBase64(input)
// Assert
assert hashB64 == EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, hashB64)
}
@Test
@ -214,8 +219,8 @@ class ScryptSecureHasherTest {
}
// Assert
assert hexResults.every { it == EXPECTED_HASH_HEX }
assert B64Results.every { it == EXPECTED_HASH_BASE64 }
hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
}
/**
@ -249,8 +254,8 @@ class ScryptSecureHasherTest {
// Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assert resultDurations.min() > MIN_DURATION_NANOS
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS
assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
}
@Test
@ -262,24 +267,16 @@ class ScryptSecureHasherTest {
boolean valid = ScryptSecureHasher.isRValid(r)
// Assert
assert valid
assertTrue(valid)
}
@Test
void testShouldFailRBoundary() throws Exception {
// Arrange
def rValues = [-8, 0, 2147483647]
List<Integer> rValues = [-8, 0, 2147483647]
// Act
def results = rValues.collect { rValue ->
def isValid = ScryptSecureHasher.isRValid(rValue)
[rValue, isValid]
}
// Assert
results.each { rValue, isRValid ->
assert !isRValid
}
// Act and Assert
rValues.forEach(rValue -> assertFalse(ScryptSecureHasher.isRValid(rValue)))
}
@Test
@ -288,28 +285,19 @@ class ScryptSecureHasherTest {
final Integer n = 16385
final int r = 8
// Act
boolean valid = ScryptSecureHasher.isNValid(n, r)
// Assert
assert valid
// Act and Assert
assertTrue(ScryptSecureHasher.isNValid(n, r))
}
@Test
void testShouldFailNBoundary() throws Exception {
// Arrange
Map costParameters = [(-8): 8, 0: 32]
Map<Integer, Integer> costParameters = [(-8): 8, 0: 32]
// Act
def results = costParameters.collect { n, p ->
def isValid = ScryptSecureHasher.isNValid(n, p)
[n, isValid]
}
// Assert
results.each { n, isNValid ->
assert !isNValid
}
//Act and Assert
costParameters.entrySet().forEach(entry -> {
assertFalse(ScryptSecureHasher.isNValid(entry.getKey(), entry.getValue()))
})
}
@Test
@ -318,19 +306,12 @@ class ScryptSecureHasherTest {
final List<Integer> ps = [1, 8, 1024]
final List<Integer> rs = [8, 1024, 4096]
// Act
def pResults = ps.collectEntries { int p ->
def rResults = rs.collectEntries { int r ->
boolean valid = ScryptSecureHasher.isPValid(p, r)
[r, valid]
}
[p, rResults]
}
// Assert
pResults.each { p, rResult ->
assert rResult.every { r, isValid -> isValid }
}
// Act and Assert
ps.forEach(p -> {
rs.forEach(r -> {
assertTrue(ScryptSecureHasher.isPValid(p, r))
})
})
}
@Test
@ -339,19 +320,12 @@ class ScryptSecureHasherTest {
final List<Integer> ps = [4096 * 64, 1024 * 1024]
final List<Integer> rs = [4096, 1024 * 1024]
// Act
def pResults = ps.collectEntries { int p ->
def rResults = rs.collectEntries { int r ->
boolean valid = ScryptSecureHasher.isPValid(p, r)
[r, valid]
}
[p, rResults]
}
// Assert
pResults.each { p, rResult ->
assert rResult.every { r, isValid -> !isValid }
}
// Act and Assert
ps.forEach(p -> {
rs.forEach(r -> {
assertFalse(ScryptSecureHasher.isPValid(p, r))
})
})
}
@Test
@ -363,7 +337,7 @@ class ScryptSecureHasherTest {
boolean valid = ScryptSecureHasher.isDKLengthValid(dkLength)
// Assert
assert valid
assertTrue(valid)
}
@Test
@ -371,16 +345,10 @@ class ScryptSecureHasherTest {
// Arrange
def dKLengths = [-8, 0, 2147483647]
// Act
def results = dKLengths.collect { dKLength ->
def isValid = ScryptSecureHasher.isDKLengthValid(dKLength)
[dKLength, isValid]
}
// Assert
results.each { dKLength, isDKLengthValid ->
assert !isDKLengthValid
}
// Act and Assert
dKLengths.forEach( dKLength -> {
assertFalse(ScryptSecureHasher.isDKLengthValid(dKLength))
})
}
@Test
@ -388,16 +356,11 @@ class ScryptSecureHasherTest {
// Arrange
def saltLengths = [0, 64]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new ScryptSecureHasher().isSaltLengthValid(saltLength)
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
assert { it == isSaltLengthValid }
}
// Act and Assert
ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher()
saltLengths.forEach(saltLength -> {
assertTrue(scryptSecureHasher.isSaltLengthValid(saltLength))
})
}
@Test
@ -405,16 +368,10 @@ class ScryptSecureHasherTest {
// Arrange
def saltLengths = [-8, 1, 2147483647]
// Act
def results = saltLengths.collect { saltLength ->
def isValid = new ScryptSecureHasher().isSaltLengthValid(saltLength)
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
assert !isSaltLengthValid
}
// Act and Assert
ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher()
saltLengths.forEach(saltLength -> {
assertFalse(scryptSecureHasher.isSaltLengthValid(saltLength))
})
}
}

View File

@ -20,7 +20,6 @@ import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.crypto.scrypt.Scrypt
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger
@ -29,7 +28,11 @@ import org.slf4j.LoggerFactory
import java.security.SecureRandom
import java.security.Security
import static groovy.test.GroovyAssert.shouldFail
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assumptions.assumeTrue
class ScryptGroovyTest {
@ -71,8 +74,8 @@ class ScryptGroovyTest {
}
// Assert
assert allKeys.size() == RUNS
assert allKeys.every { it == allKeys.first() }
assertEquals(RUNS, allKeys.size())
allKeys.forEach(key -> assertArrayEquals(allKeys.first(), key))
}
/**
@ -122,7 +125,7 @@ class ScryptGroovyTest {
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert
assert calculatedHash == params.hash
assertArrayEquals(params.hash, calculatedHash)
}
}
@ -163,7 +166,7 @@ class ScryptGroovyTest {
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert
assert calculatedHash == HASH
assertArrayEquals(HASH, calculatedHash)
}
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true")
@ -203,7 +206,7 @@ class ScryptGroovyTest {
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert
assert calculatedHash == Hex.decodeHex(EXPECTED_KEY_HEX as char[])
assertArrayEquals(Hex.decodeHex(EXPECTED_KEY_HEX as char[]), calculatedHash)
}
@Test
@ -222,8 +225,8 @@ class ScryptGroovyTest {
}
// Assert
assert allHashes.size() == RUNS
assert allHashes.every { it == allHashes.first() }
assertEquals(RUNS, allHashes.size())
allHashes.forEach(hash -> assertEquals(allHashes.first(), hash))
}
@Test
@ -231,14 +234,14 @@ class ScryptGroovyTest {
// Arrange
// The generated salt should be byte[16], encoded as 22 Base64 chars
final def EXPECTED_SALT_PATTERN = /\$.+\$[0-9a-zA-Z\/\+]{22}\$.+/
final EXPECTED_SALT_PATTERN = /\$.+\$[0-9a-zA-Z\/\+]{22}\$.+/
// Act
String calculatedHash = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN)
logger.info("Generated ${calculatedHash}")
// Assert
assert calculatedHash =~ EXPECTED_SALT_PATTERN
assertTrue((calculatedHash =~ EXPECTED_SALT_PATTERN).matches())
}
@Test
@ -254,12 +257,11 @@ class ScryptGroovyTest {
INVALID_NS.each { int invalidN ->
logger.info("Using N: ${invalidN}")
def msg = shouldFail(IllegalArgumentException) {
Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, invalidN, R, P, DK_LEN)
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, invalidN, R, P, DK_LEN))
// Assert
assert msg =~ "N must be a power of 2 greater than 1|Parameter N is too large"
assertTrue((iae.getMessage() =~ "N must be a power of 2 greater than 1|Parameter N is too large").matches())
}
}
@ -278,13 +280,11 @@ class ScryptGroovyTest {
INVALID_RS.each { int invalidR ->
logger.info("Using r: ${invalidR}")
def msg = shouldFail(IllegalArgumentException) {
byte[] hash = Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, invalidR, largeP, DK_LEN)
logger.info("Generated hash: ${Hex.encodeHexString(hash)}")
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, invalidR, largeP, DK_LEN))
// Assert
assert msg =~ "Parameter r must be 1 or greater|Parameter r is too large"
assertTrue((iae.getMessage() =~ "Parameter r must be 1 or greater|Parameter r is too large").matches())
}
}
@ -300,13 +300,11 @@ class ScryptGroovyTest {
INVALID_PS.each { int invalidP ->
logger.info("Using p: ${invalidP}")
def msg = shouldFail(IllegalArgumentException) {
byte[] hash = Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, R, invalidP, DK_LEN)
logger.info("Generated hash: ${Hex.encodeHexString(hash)}")
}
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, R, invalidP, DK_LEN))
// Assert
assert msg =~ "Parameter p must be 1 or greater|Parameter p is too large"
assertTrue((iae.getMessage() =~ "Parameter p must be 1 or greater|Parameter p is too large").matches())
}
}
@ -322,7 +320,7 @@ class ScryptGroovyTest {
logger.info("Check matches: ${matches}")
// Assert
assert matches
assertTrue(matches)
}
@Test
@ -337,7 +335,7 @@ class ScryptGroovyTest {
logger.info("Check matches: ${matches}")
// Assert
assert !matches
assertFalse(matches)
}
@Test
@ -351,14 +349,12 @@ class ScryptGroovyTest {
// Act
INVALID_PASSWORDS.each { String invalidPassword ->
logger.info("Using password: ${invalidPassword}")
def msg = shouldFail(IllegalArgumentException) {
boolean matches = Scrypt.check(invalidPassword, HASH)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.check(invalidPassword, HASH))
logger.expected(iae.getMessage())
// Assert
assert msg =~ "Password cannot be empty"
assertTrue(iae.getMessage().contains("Password cannot be empty"))
}
}
@ -373,14 +369,12 @@ class ScryptGroovyTest {
// Act
INVALID_HASHES.each { String invalidHash ->
logger.info("Using hash: ${invalidHash}")
def msg = shouldFail(IllegalArgumentException) {
boolean matches = Scrypt.check(PASSWORD, invalidHash)
}
logger.expected(msg)
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> Scrypt.check(PASSWORD, invalidHash))
logger.expected(iae.getMessage())
// Assert
assert msg =~ "Hash cannot be empty|Hash is not properly formatted"
assertTrue((iae.getMessage() =~ "Hash cannot be empty|Hash is not properly formatted").matches())
}
}
@ -401,7 +395,8 @@ class ScryptGroovyTest {
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$A",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$" +
"ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A",
@ -417,7 +412,7 @@ class ScryptGroovyTest {
logger.info("Hash is valid: ${isValidHash}")
// Assert
assert isValidHash
assertTrue(isValidHash)
}
}
@ -436,7 +431,7 @@ class ScryptGroovyTest {
logger.info("Hash is valid: ${isValidHash}")
// Assert
assert !isValidHash
assertFalse(isValidHash)
}
}
}

View File

@ -38,9 +38,10 @@ import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.UUID;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class KeyStoreUtilsTest {
private static final String SIGNING_ALGORITHM = "SHA256withRSA";
@ -89,18 +90,18 @@ public class KeyStoreUtilsTest {
);
final TlsConfiguration configuration = KeyStoreUtils.createTlsConfigAndNewKeystoreTruststore(requested, 1, new String[] { HOSTNAME });
final File keystoreFile = new File(configuration.getKeystorePath());
assertTrue("Keystore File not found", keystoreFile.exists());
assertTrue(keystoreFile.exists(), "Keystore File not found");
keystoreFile.deleteOnExit();
final File truststoreFile = new File(configuration.getTruststorePath());
assertTrue("Truststore File not found", truststoreFile.exists());
assertTrue(truststoreFile.exists(),"Truststore File not found");
truststoreFile.deleteOnExit();
assertEquals("Keystore Type not matched", KeystoreType.PKCS12, configuration.getKeystoreType());
assertEquals("Truststore Type not matched", KeystoreType.PKCS12, configuration.getTruststoreType());
assertEquals(KeystoreType.PKCS12, configuration.getKeystoreType(), "Keystore Type not matched");
assertEquals(KeystoreType.PKCS12, configuration.getTruststoreType(), "Truststore Type not matched");
assertTrue("Keystore not valid", KeyStoreUtils.isStoreValid(keystoreFile.toURI().toURL(), configuration.getKeystoreType(), configuration.getKeystorePassword().toCharArray()));
assertTrue("Truststore not valid", KeyStoreUtils.isStoreValid(truststoreFile.toURI().toURL(), configuration.getTruststoreType(), configuration.getTruststorePassword().toCharArray()));
assertTrue(KeyStoreUtils.isStoreValid(keystoreFile.toURI().toURL(), configuration.getKeystoreType(), configuration.getKeystorePassword().toCharArray()), "Keystore not valid");
assertTrue(KeyStoreUtils.isStoreValid(truststoreFile.toURI().toURL(), configuration.getTruststoreType(), configuration.getTruststorePassword().toCharArray()), "Truststore not valid");
}
@Test
@ -141,7 +142,7 @@ public class KeyStoreUtilsTest {
sourceKeyStore.setCertificateEntry(ALIAS, certificate);
final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore);
assertEquals(String.format("[%s] Certificate not matched", sourceKeyStore.getType()), certificate, copiedKeyStore.getCertificate(ALIAS));
assertEquals(certificate, copiedKeyStore.getCertificate(ALIAS), String.format("[%s] Certificate not matched", sourceKeyStore.getType()));
}
private void assertKeyEntryStoredLoaded(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException {
@ -151,13 +152,13 @@ public class KeyStoreUtilsTest {
final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore);
final KeyStore.Entry entry = copiedKeyStore.getEntry(ALIAS, new KeyStore.PasswordProtection(KEY_PASSWORD));
assertTrue(String.format("[%s] Private Key entry not found", sourceKeyStore.getType()), entry instanceof KeyStore.PrivateKeyEntry);
assertInstanceOf(KeyStore.PrivateKeyEntry.class, entry, String.format("[%s] Private Key entry not found", sourceKeyStore.getType()));
final KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) entry;
final Certificate[] entryCertificateChain = privateKeyEntry.getCertificateChain();
assertArrayEquals(String.format("[%s] Certificate Chain not matched", sourceKeyStore.getType()), certificateChain, entryCertificateChain);
assertEquals(String.format("[%s] Private Key not matched", sourceKeyStore.getType()), keyPair.getPrivate(), privateKeyEntry.getPrivateKey());
assertEquals(String.format("[%s] Public Key not matched", sourceKeyStore.getType()), keyPair.getPublic(), entryCertificateChain[0].getPublicKey());
assertArrayEquals(certificateChain, entryCertificateChain, String.format("[%s] Certificate Chain not matched", sourceKeyStore.getType()));
assertEquals(keyPair.getPrivate(), privateKeyEntry.getPrivateKey(), String.format("[%s] Private Key not matched", sourceKeyStore.getType()));
assertEquals(keyPair.getPublic(), entryCertificateChain[0].getPublicKey(), String.format("[%s] Public Key not matched", sourceKeyStore.getType()));
}
private void assertSecretKeyStoredLoaded(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException {
@ -167,7 +168,7 @@ public class KeyStoreUtilsTest {
final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore);
final KeyStore.Entry entry = copiedKeyStore.getEntry(ALIAS, protection);
assertTrue(String.format("[%s] Secret Key entry not found", sourceKeyStore.getType()), entry instanceof KeyStore.SecretKeyEntry);
assertInstanceOf(KeyStore.SecretKeyEntry.class, entry, String.format("[%s] Secret Key entry not found", sourceKeyStore.getType()));
}
private KeyStore copyKeyStore(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException {

View File

@ -18,7 +18,7 @@ package org.apache.nifi.security.util.crypto;
import org.junit.jupiter.api.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class SecureHasherFactoryTest {

View File

@ -33,7 +33,12 @@ import org.slf4j.LoggerFactory
import java.security.Security
import java.util.concurrent.ArrayBlockingQueue
class PeerSelectorTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertTrue
class PeerSelectorTest {
private static final Logger logger = LoggerFactory.getLogger(PeerSelectorTest.class)
private static final BOOTSTRAP_PEER_DESCRIPTION = new PeerDescription("localhost", -1, false)
@ -135,7 +140,7 @@ class PeerSelectorTest extends GroovyTestCase {
final Map<String, Double> EXPECTED_PERCENTS,
final int NUM_TIMES = resultsFrequency.values().sum() as int,
final double TOLERANCE = 0.05) {
assert resultsFrequency.keySet() == EXPECTED_PERCENTS.keySet()
assertEquals(EXPECTED_PERCENTS.keySet(), resultsFrequency.keySet())
logger.info(" Actual results: ${resultsFrequency.sort()}")
logger.info("Expected results: ${EXPECTED_PERCENTS.sort().collect { k, v -> "${k}: ${v}%" }}")
@ -154,7 +159,7 @@ class PeerSelectorTest extends GroovyTestCase {
def count = resultsFrequency[k]
def difference = Math.abs(expectedCount - count) / NUM_TIMES
logger.debug("Checking that ${count} is within ±${TOLERANCE * 100}% of ${expectedCount} (${lowerBound}, ${upperBound}) | ${(difference * 100).round(2)}%")
assert count >= lowerBound && count <= upperBound
assertTrue(count >= lowerBound && count <= upperBound)
}
}
@ -169,7 +174,7 @@ class PeerSelectorTest extends GroovyTestCase {
int consecutiveElements = recentPeerSelectionQueue.getMaxConsecutiveElements()
// String mcce = recentPeerSelectionQueue.getMostCommonConsecutiveElement()
// logger.debug("Most consecutive elements in recentPeerSelectionQueue: ${consecutiveElements} - ${mcce} | ${recentPeerSelectionQueue}")
assert consecutiveElements <= recentPeerSelectionQueue.totalSize - 1
assertTrue(consecutiveElements <= recentPeerSelectionQueue.totalSize - 1)
}
private static double calculateMean(Map resultsFrequency) {
@ -179,7 +184,9 @@ class PeerSelectorTest extends GroovyTestCase {
return meanElements.sum() / meanElements.size()
}
private static PeerStatusProvider mockPeerStatusProvider(PeerDescription bootstrapPeerDescription = BOOTSTRAP_PEER_DESCRIPTION, String remoteInstanceUris = DEFAULT_REMOTE_INSTANCE_URIS, Map<PeerDescription, Set<PeerStatus>> peersMap = DEFAULT_PEER_NODES) {
private static PeerStatusProvider mockPeerStatusProvider(PeerDescription bootstrapPeerDescription = BOOTSTRAP_PEER_DESCRIPTION,
String remoteInstanceUris = DEFAULT_REMOTE_INSTANCE_URIS,
Map<PeerDescription, Set<PeerStatus>> peersMap = DEFAULT_PEER_NODES) {
[getTransportProtocol : { ->
SiteToSiteTransportProtocol.HTTP
},
@ -230,8 +237,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert
assert peersToQuery.size() == 1
assert peersToQuery.first() == BOOTSTRAP_PEER_DESCRIPTION
assertEquals(1, peersToQuery.size())
assertEquals(BOOTSTRAP_PEER_DESCRIPTION, peersToQuery.first())
}
@Test
@ -249,9 +256,9 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert
assert peersToQuery.size() == restoredPeerStatuses.size() + 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS)
assertEquals(restoredPeerStatuses.size() + 1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
}
/**
@ -275,11 +282,11 @@ class PeerSelectorTest extends GroovyTestCase {
}
// Assert
assert peersToQuery.size() == DEFAULT_PEER_STATUSES.size() + 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS)
assertEquals(DEFAULT_PEER_STATUSES.size() + 1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
assert repeatedPeersToQuery.every { it == peersToQuery }
repeatedPeersToQuery.forEach(query -> assertEquals(peersToQuery, query))
}
@Test
@ -292,8 +299,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${remotePeerStatuses.size()} peer statuses: ${remotePeerStatuses}")
// Assert
assert remotePeerStatuses.size() == DEFAULT_PEER_STATUSES.size()
assert remotePeerStatuses.containsAll(DEFAULT_PEER_STATUSES)
assertEquals(DEFAULT_PEER_STATUSES.size(), remotePeerStatuses.size())
assertTrue(remotePeerStatuses.containsAll(DEFAULT_PEER_STATUSES))
}
/**
@ -339,10 +346,10 @@ class PeerSelectorTest extends GroovyTestCase {
}
// Assert that the send percentage is always between 0% and 80%
assert r.every { k, v -> v.send >= 0 && v.send <= 80 }
r.every { k, v -> assertTrue(v.send >= 0 && v.send <= 80) }
// Assert that the receive percentage is always between 0% and 100%
assert r.every { k, v -> v.receive >= 0 && v.receive <= 100 }
r.every { k, v -> assertTrue(v.receive >= 0 && v.receive <= 100) }
}
}
}
@ -365,8 +372,8 @@ class PeerSelectorTest extends GroovyTestCase {
double receiveWeight = PeerSelector.calculateNormalizedWeight(TransferDirection.RECEIVE, totalFlowfileCount, flowfileCount, NODE_COUNT)
// Assert
assert sendWeight == 100
assert receiveWeight == 100
assertEquals(100, sendWeight)
assertEquals(100, receiveWeight)
}
}
}
@ -392,7 +399,7 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Weighted peer map: ${weightedPeerMap}")
// Assert
assert new ArrayList<>(weightedPeerMap.keySet()) == new ArrayList(clusterMap.keySet())
assertEquals(clusterMap.keySet(), weightedPeerMap.keySet())
}
@Test
@ -416,7 +423,7 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Weighted peer map: ${weightedPeerMap}")
// Assert
assert new ArrayList<>(weightedPeerMap.keySet()) == new ArrayList(clusterMap.keySet())
assertEquals(clusterMap.keySet(), weightedPeerMap.keySet())
}
/**
@ -450,11 +457,11 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Destination map: ${destinationMap}")
// Assert
assert destinationMap.keySet() == peerStatuses
assertEquals(peerStatuses, destinationMap.keySet())
// For uneven splits, the resulting percentage should be within +/- 1%
def totalPercentage = destinationMap.values().sum()
assert totalPercentage >= 99 && totalPercentage <= 100
assertTrue(totalPercentage >= 99 && totalPercentage <= 100)
}
}
}
@ -603,7 +610,7 @@ class PeerSelectorTest extends GroovyTestCase {
// Spot check consecutive selection
if (i % 10 == 0) {
int consecutiveElements = lastN.getMaxConsecutiveElements()
assert consecutiveElements == lastN.size()
assertEquals(lastN.size(), consecutiveElements)
}
}
@ -784,7 +791,8 @@ class PeerSelectorTest extends GroovyTestCase {
cacheFile.deleteOnExit()
// Construct the cache contents and write to disk
final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" + "${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps ->
final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" +
"${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps ->
[ps.peerDescription.hostname, ps.peerDescription.port, ps.peerDescription.isSecure(), ps.isQueryForPeers()].join(":")
}.join("\n")
cacheFile.text = CACHE_CONTENTS
@ -801,9 +809,9 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert
assert peersToQuery.size() == nodes.size() + 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS)
assertEquals(nodes.size() + 1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
}
/**
@ -824,7 +832,8 @@ class PeerSelectorTest extends GroovyTestCase {
cacheFile.deleteOnExit()
// Construct the cache contents and write to disk
final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" + "${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps ->
final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" +
"${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps ->
[ps.peerDescription.hostname, ps.peerDescription.port, ps.peerDescription.isSecure(), ps.isQueryForPeers()].join(":")
}.join("\n")
cacheFile.text = CACHE_CONTENTS
@ -842,16 +851,16 @@ class PeerSelectorTest extends GroovyTestCase {
// Assert
// The loaded cache should be marked as expired and not used
assert ps.isCacheExpired(ps.peerStatusCache)
assertTrue(ps.isCacheExpired(ps.peerStatusCache))
// This internal method does not refresh or check expiration
def peersToQuery = ps.getPeersToQuery()
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// The cache has (expired) peer statuses present
assert peersToQuery.size() == nodes.size() + 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS)
assertEquals(nodes.size() + 1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
// Trigger the cache expiration detection
ps.refresh()
@ -860,8 +869,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("After cache expiration, retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// The cache only contains the bootstrap node
assert peersToQuery.size() == 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assertEquals(1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
}
Throwable generateException(String message, int nestedLevel = 0) {
@ -895,8 +904,8 @@ class PeerSelectorTest extends GroovyTestCase {
def peersToQuery = ps.getPeersToQuery()
// Assert
assert peersToQuery.size() == 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assertEquals(1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
}
/**
@ -931,8 +940,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert
assert peersToQuery.size() == 1
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION)
assertEquals(1, peersToQuery.size())
assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
}
/**
@ -998,7 +1007,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh()
PeerStatus peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert peerStatus
assertNotNull(peerStatus)
// Force the selector to refresh the cache
currentAttempt++
@ -1009,7 +1018,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh()
peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert peerStatus == node2Status
assertEquals(node2Status, peerStatus)
// Force the selector to refresh the cache
currentAttempt++
@ -1020,7 +1029,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh()
peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert !peerStatus
assertNull(peerStatus)
// Force the selector to refresh the cache
currentAttempt = 5
@ -1030,7 +1039,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh()
peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert peerStatus == bootstrapStatus
assertEquals(bootstrapStatus, peerStatus)
}
// PeerQueue definition and tests
@ -1052,7 +1061,7 @@ class PeerSelectorTest extends GroovyTestCase {
peerQueue.append(nodes.first())
// Assert
assert peerQueue.getMaxConsecutiveElements() == peerQueue.size()
assertEquals(peerQueue.size(), peerQueue.getMaxConsecutiveElements())
}
// Never repeating node
@ -1061,7 +1070,7 @@ class PeerSelectorTest extends GroovyTestCase {
peerQueue.append(nodes.get(i % peerStatuses.size()))
// Assert
assert peerQueue.getMaxConsecutiveElements() == 1
assertEquals(1, peerQueue.getMaxConsecutiveElements())
}
// Repeat up to nodes.size() times but no more
@ -1072,7 +1081,7 @@ class PeerSelectorTest extends GroovyTestCase {
// Assert
// logger.debug("Most consecutive elements in queue: ${peerQueue.getMaxConsecutiveElements()} | ${peerQueue}")
assert peerQueue.getMaxConsecutiveElements() <= peerStatuses.size()
assertTrue(peerQueue.getMaxConsecutiveElements() <= peerStatuses.size())
}
}

View File

@ -29,7 +29,10 @@ import org.slf4j.LoggerFactory
import javax.net.ssl.SSLServerSocket
import java.security.Security
class SocketUtilsTest extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertFalse
class SocketUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(SocketUtilsTest.class)
private static final String KEYSTORE_PATH = "src/test/resources/TlsConfigurationKeystore.jks"
@ -56,7 +59,9 @@ class SocketUtilsTest extends GroovyTestCase {
private NiFiProperties mockNiFiProperties = NiFiProperties.createBasicNiFiProperties(null, DEFAULT_PROPS)
// A static TlsConfiguration referencing the test resource keystore and truststore
// private static final TlsConfiguration TLS_CONFIGURATION = new StandardTlsConfiguration(KEYSTORE_PATH, KEYSTORE_PASSWORD, KEY_PASSWORD, KEYSTORE_TYPE, TRUSTSTORE_PATH, TRUSTSTORE_PASSWORD, TRUSTSTORE_TYPE, PROTOCOL)
// private static final TlsConfiguration TLS_CONFIGURATION =
// new StandardTlsConfiguration(KEYSTORE_PATH, KEYSTORE_PASSWORD, KEY_PASSWORD, KEYSTORE_TYPE, TRUSTSTORE_PATH,
// TRUSTSTORE_PASSWORD, TRUSTSTORE_TYPE, PROTOCOL)
// private static final SSLContext sslContext = SslContextFactory.createSslContext(TLS_CONFIGURATION, ClientAuth.NONE)
@BeforeAll
@ -81,9 +86,9 @@ class SocketUtilsTest extends GroovyTestCase {
// Assert
String[] enabledProtocols = sslServerSocket.getEnabledProtocols()
logger.info("Enabled protocols: ${enabledProtocols}")
assert enabledProtocols == TlsConfiguration.getCurrentSupportedTlsProtocolVersions()
assert !enabledProtocols.contains("TLSv1")
assert !enabledProtocols.contains("TLSv1.1")
assertArrayEquals(TlsConfiguration.getCurrentSupportedTlsProtocolVersions(), enabledProtocols)
assertFalse(enabledProtocols.contains("TLSv1"))
assertFalse(enabledProtocols.contains("TLSv1.1"))
}
@Test
@ -99,8 +104,8 @@ class SocketUtilsTest extends GroovyTestCase {
// Assert
String[] enabledProtocols = sslServerSocket.getEnabledProtocols()
logger.info("Enabled protocols: ${enabledProtocols}")
assert enabledProtocols == TlsConfiguration.getCurrentSupportedTlsProtocolVersions()
assert !enabledProtocols.contains("TLSv1")
assert !enabledProtocols.contains("TLSv1.1")
assertArrayEquals(TlsConfiguration.getCurrentSupportedTlsProtocolVersions(), enabledProtocols)
assertFalse(enabledProtocols.contains("TLSv1"))
assertFalse(enabledProtocols.contains("TLSv1.1"))
}
}

View File

@ -22,8 +22,13 @@ import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.concurrent.TimeUnit
import java.util.stream.IntStream
class TestFormatUtilsGroovy extends GroovyTestCase {
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class TestFormatUtilsGroovy {
private static final Logger logger = LoggerFactory.getLogger(TestFormatUtilsGroovy.class)
@BeforeAll
@ -49,7 +54,7 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(days)
// Assert
assert days.every { it == EXPECTED_DAYS }
days.forEach(it -> assertEquals(EXPECTED_DAYS, it))
}
@ -59,14 +64,12 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
final List WEEKS = ["-1 week", "-1 wk", "-1 w", "-1 weeks", "- 1 week"]
// Act
List msgs = WEEKS.collect { String week ->
shouldFail(IllegalArgumentException) {
FormatUtils.getTimeDuration(week, TimeUnit.DAYS)
}
}
// Assert
assert msgs.every { it =~ /Value '.*' is not a valid time duration/ }
WEEKS.stream().forEach(week -> {
IllegalArgumentException iae =
assertThrows(IllegalArgumentException.class, () -> FormatUtils.getTimeDuration(week, TimeUnit.DAYS))
// Assert
assertTrue(iae.message.contains("Value '" + week + "' is not a valid time duration"))
})
}
/**
@ -78,14 +81,12 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
final List WEEKS = ["1 work", "1 wek", "1 k"]
// Act
List msgs = WEEKS.collect { String week ->
shouldFail(IllegalArgumentException) {
FormatUtils.getTimeDuration(week, TimeUnit.DAYS)
}
}
// Assert
assert msgs.every { it =~ /Value '.*' is not a valid time duration/ }
WEEKS.stream().forEach(week -> {
IllegalArgumentException iae =
assertThrows(IllegalArgumentException.class, () -> FormatUtils.getTimeDuration(week, TimeUnit.DAYS))
// Assert
assertTrue(iae.message.contains("Value '" + week + "' is not a valid time duration"))
})
}
@ -105,7 +106,7 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(days)
// Assert
assert days.every { it == EXPECTED_DAYS }
days.forEach(it -> assertEquals(EXPECTED_DAYS, it))
}
/**
@ -130,8 +131,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedDecimalMillis)
// Assert
assert parsedWholeMillis.every { it == EXPECTED_MILLIS }
assert parsedDecimalMillis.every { it == EXPECTED_MILLIS }
parsedWholeMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
parsedDecimalMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
}
/**
@ -158,9 +159,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(oneWeekInOtherUnits)
// Assert
oneWeekInOtherUnits.each { TimeUnit k, double value ->
assert value == ONE_WEEK_IN_OTHER_UNITS[k]
}
oneWeekInOtherUnits.entrySet().forEach(entry ->
assertEquals(ONE_WEEK_IN_OTHER_UNITS.get(entry.getKey()), entry.getValue()))
}
/**
@ -187,9 +187,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(onePointFiveWeeksInOtherUnits)
// Assert
onePointFiveWeeksInOtherUnits.each { TimeUnit k, double value ->
assert value == ONE_POINT_FIVE_WEEKS_IN_OTHER_UNITS[k]
}
onePointFiveWeeksInOtherUnits.entrySet().forEach(entry ->
assertEquals(ONE_POINT_FIVE_WEEKS_IN_OTHER_UNITS.get(entry.getKey()), entry.getValue()))
}
/**
@ -214,8 +213,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedDecimalMillis)
// Assert
assert parsedWholeMillis.every { it == EXPECTED_MILLIS }
assert parsedDecimalMillis.every { it == EXPECTED_MILLIS }
parsedWholeMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
parsedDecimalMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
}
/**
@ -240,9 +239,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.info(results)
// Assert
results.every { String key, double value ->
assert value == SCENARIOS[key].expectedValue
}
results.entrySet().forEach(entry ->
assertEquals(SCENARIOS.get(entry.getKey()).expectedValue, entry.getValue()))
}
/**
@ -265,7 +263,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedWholeNanos)
// Assert
assert parsedWholeNanos.every { it == [EXPECTED_NANOS, TimeUnit.NANOSECONDS] }
parsedWholeNanos.forEach(it ->
assertEquals(Arrays.asList(EXPECTED_NANOS, TimeUnit.NANOSECONDS), it))
}
/**
@ -291,8 +290,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
// Assert
results.every { String key, List values ->
assert values.first() == SCENARIOS[key].expectedValue
assert values.last() == SCENARIOS[key].expectedUnits
assertEquals(SCENARIOS[key].expectedValue, values.first())
assertEquals(SCENARIOS[key].expectedUnits, values.last())
}
}
@ -316,10 +315,10 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.info(results)
// Assert
results.every { String key, List values ->
assert values.first() == SCENARIOS[key].expectedValue
assert values.last() == SCENARIOS[key].expectedUnits
}
results.entrySet().forEach(entry -> {
assertEquals(SCENARIOS.get(entry.getKey()).expectedValue, entry.getValue().get(0))
assertEquals(SCENARIOS.get(entry.getKey()).expectedUnits, entry.getValue().get(1))
})
}
/**
@ -339,17 +338,18 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
]
// Act
List parsedWholeTimes = WHOLE_TIMES.collect { List it ->
List<List<Object>> parsedWholeTimes = WHOLE_TIMES.collect { List it ->
FormatUtils.makeWholeNumberTime(it[0] as float, it[1] as TimeUnit)
}
logger.converted(parsedWholeTimes)
// Assert
parsedWholeTimes.eachWithIndex { List elements, int i ->
assert elements[0] instanceof Long
assert elements[0] == 10L
assert elements[1] == WHOLE_TIMES[i][1]
}
IntStream.range(0, parsedWholeTimes.size())
.forEach(index -> {
List<Object> elements = parsedWholeTimes.get(index)
assertEquals(10L, elements.get(0))
assertEquals(WHOLE_TIMES[index][1], elements.get(1))
})
}
/**
@ -379,7 +379,7 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedWholeTimes)
// Assert
assert parsedWholeTimes == EXPECTED_TIMES
assertEquals(EXPECTED_TIMES, parsedWholeTimes)
}
/**
@ -391,15 +391,10 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
final List UNITS = TimeUnit.values() as List
// Act
def nullMsg = shouldFail(IllegalArgumentException) {
FormatUtils.getSmallerTimeUnit(null)
}
logger.expected(nullMsg)
def nanosMsg = shouldFail(IllegalArgumentException) {
FormatUtils.getSmallerTimeUnit(TimeUnit.NANOSECONDS)
}
logger.expected(nanosMsg)
IllegalArgumentException nullIae = assertThrows(IllegalArgumentException.class,
() -> FormatUtils.getSmallerTimeUnit(null))
IllegalArgumentException nanoIae = assertThrows(IllegalArgumentException.class,
() -> FormatUtils.getSmallerTimeUnit(TimeUnit.NANOSECONDS))
List smallerTimeUnits = UNITS[1..-1].collect { TimeUnit unit ->
FormatUtils.getSmallerTimeUnit(unit)
@ -407,9 +402,9 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(smallerTimeUnits)
// Assert
assert nullMsg == "Cannot determine a smaller time unit than 'null'"
assert nanosMsg == "Cannot determine a smaller time unit than 'NANOSECONDS'"
assert smallerTimeUnits == UNITS[0..<-1]
assertEquals("Cannot determine a smaller time unit than 'null'", nullIae.getMessage())
assertEquals("Cannot determine a smaller time unit than 'NANOSECONDS'", nanoIae.getMessage())
assertEquals(smallerTimeUnits, UNITS.subList(0, UNITS.size() - 1))
}
/**
@ -435,9 +430,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(results)
// Assert
results.every { String key, long value ->
assert value == SCENARIOS[key].expectedMultiplier
}
results.entrySet().forEach(entry ->
assertEquals(SCENARIOS.get(entry.getKey()).expectedMultiplier, entry.getValue()))
}
/**
@ -446,26 +440,21 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
@Test
void testCalculateMultiplierShouldHandleIncorrectUnits() {
// Arrange
final Map SCENARIOS = [
final Map<String, Map<String, TimeUnit>> SCENARIOS = [
"allUnits" : [original: TimeUnit.NANOSECONDS, destination: TimeUnit.DAYS],
"nanosToMicros": [original: TimeUnit.NANOSECONDS, destination: TimeUnit.MICROSECONDS],
"hoursToDays" : [original: TimeUnit.HOURS, destination: TimeUnit.DAYS],
]
// Act
Map results = SCENARIOS.collectEntries { String k, Map values ->
logger.debug("Evaluating ${k}: ${values}")
def msg = shouldFail(IllegalArgumentException) {
FormatUtils.calculateMultiplier(values.original, values.destination)
}
logger.expected(msg)
[k, msg]
}
// Assert
results.every { String key, String value ->
assert value =~ "The original time unit '.*' must be larger than the new time unit '.*'"
}
SCENARIOS.entrySet().stream()
.forEach(entry -> {
// Assert
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
() -> FormatUtils.calculateMultiplier(entry.getValue().get("original"),
entry.getValue().get("destination")))
assertTrue((iae.getMessage() =~ "The original time unit '.*' must be larger than the new time unit '.*'").find())
})
}
// TODO: Microsecond parsing

View File

@ -22,7 +22,7 @@ import org.junit.jupiter.api.Test;
import java.text.DecimalFormatSymbols;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestFormatUtils {
@ -38,7 +38,7 @@ public class TestFormatUtils {
}
@Test
public void testFormatTime() throws Exception {
public void testFormatTime() {
assertEquals("00:00:00.000", FormatUtils.formatHoursMinutesSeconds(0, TimeUnit.DAYS));
assertEquals("01:00:00.000", FormatUtils.formatHoursMinutesSeconds(1, TimeUnit.HOURS));
assertEquals("02:00:00.000", FormatUtils.formatHoursMinutesSeconds(2, TimeUnit.HOURS));

View File

@ -22,9 +22,9 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestCompressionInputOutputStreams {
@ -32,7 +32,7 @@ public class TestCompressionInputOutputStreams {
public void testSimple() throws IOException {
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
final byte[] data = "Hello, World!".getBytes("UTF-8");
final byte[] data = "Hello, World!".getBytes(StandardCharsets.UTF_8);
final CompressionOutputStream cos = new CompressionOutputStream(baos);
cos.write(data);
@ -43,7 +43,7 @@ public class TestCompressionInputOutputStreams {
final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes));
final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data, decompressed));
assertArrayEquals(data, decompressed);
}
@Test
@ -54,7 +54,7 @@ public class TestCompressionInputOutputStreams {
for (int i = 0; i < 100; i++) {
sb.append(str);
}
final byte[] data = sb.toString().getBytes("UTF-8");
final byte[] data = sb.toString().getBytes(StandardCharsets.UTF_8);
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -67,13 +67,13 @@ public class TestCompressionInputOutputStreams {
final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes));
final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data, decompressed));
assertArrayEquals(data, decompressed);
}
@Test
public void testDataLargerThanBufferWhileFlushing() throws IOException {
final String str = "The quick brown fox jumps over the lazy dog\r\n\n\n\r";
final byte[] data = str.getBytes("UTF-8");
final byte[] data = str.getBytes(StandardCharsets.UTF_8);
final StringBuilder sb = new StringBuilder();
final byte[] data1024;
@ -87,19 +87,19 @@ public class TestCompressionInputOutputStreams {
sb.append(str);
}
cos.close();
data1024 = sb.toString().getBytes("UTF-8");
data1024 = sb.toString().getBytes(StandardCharsets.UTF_8);
final byte[] compressedBytes = baos.toByteArray();
final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes));
final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data1024, decompressed));
assertArrayEquals(data1024, decompressed);
}
@Test
public void testSendingMultipleFilesBackToBackOnSameStream() throws IOException {
final String str = "The quick brown fox jumps over the lazy dog\r\n\n\n\r";
final byte[] data = str.getBytes("UTF-8");
final byte[] data = str.getBytes(StandardCharsets.UTF_8);
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -122,18 +122,18 @@ public class TestCompressionInputOutputStreams {
for (int i = 0; i < 512; i++) {
sb.append(str);
}
data512 = sb.toString().getBytes("UTF-8");
data512 = sb.toString().getBytes(StandardCharsets.UTF_8);
final byte[] compressedBytes = baos.toByteArray();
final ByteArrayInputStream bais = new ByteArrayInputStream(compressedBytes);
final CompressionInputStream cis = new CompressionInputStream(bais);
final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data512, decompressed));
assertArrayEquals(data512, decompressed);
final CompressionInputStream cis2 = new CompressionInputStream(bais);
final byte[] decompressed2 = readFully(cis2);
assertTrue(Arrays.equals(data512, decompressed2));
assertArrayEquals(data512, decompressed2);
}
private byte[] readFully(final InputStream in) throws IOException {

View File

@ -32,7 +32,7 @@ import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertLinesMatch;
public class TestLineDemarcator {
@ -41,14 +41,14 @@ public class TestLineDemarcator {
final String input = "A\nB\nC\rD\r\nE\r\nF\r\rG";
final List<String> lines = getLines(input);
assertEquals(Arrays.asList("A\n", "B\n", "C\r", "D\r\n", "E\r\n", "F\r", "\r", "G"), lines);
assertLinesMatch(Arrays.asList("A\n", "B\n", "C\r", "D\r\n", "E\r\n", "F\r", "\r", "G"), lines);
}
@Test
public void testEmptyStream() throws IOException {
final List<String> lines = getLines("");
assertEquals(Collections.emptyList(), lines);
assertLinesMatch(Collections.emptyList(), lines);
}
@Test
@ -56,7 +56,7 @@ public class TestLineDemarcator {
final String input = "\r\r\r\n\n\n\r\n";
final List<String> lines = getLines(input);
assertEquals(Arrays.asList("\r", "\r", "\r\n", "\n", "\n", "\r\n"), lines);
assertLinesMatch(Arrays.asList("\r", "\r", "\r\n", "\n", "\n", "\r\n"), lines);
}
@Test
@ -64,25 +64,25 @@ public class TestLineDemarcator {
final String input = "ABC\r\nXYZ";
final List<String> lines = getLines(input, 10, 4);
assertEquals(Arrays.asList("ABC\r\n", "XYZ"), lines);
assertLinesMatch(Arrays.asList("ABC\r\n", "XYZ"), lines);
}
@Test
public void testEndsWithCarriageReturn() throws IOException {
final List<String> lines = getLines("ABC\r");
assertEquals(Arrays.asList("ABC\r"), lines);
assertLinesMatch(Arrays.asList("ABC\r"), lines);
}
@Test
public void testEndsWithNewLine() throws IOException {
final List<String> lines = getLines("ABC\n");
assertEquals(Arrays.asList("ABC\n"), lines);
assertLinesMatch(Arrays.asList("ABC\n"), lines);
}
@Test
public void testEndsWithCarriageReturnNewLine() throws IOException {
final List<String> lines = getLines("ABC\r\n");
assertEquals(Arrays.asList("ABC\r\n"), lines);
assertLinesMatch(Arrays.asList("ABC\r\n"), lines);
}
@Test
@ -90,13 +90,13 @@ public class TestLineDemarcator {
final String input = "he\ra-to-a\rb-to-b\rc-to-c\r\nd-to-d";
final List<String> lines = getLines(input, 10, 10);
assertEquals(Arrays.asList("he\r", "a-to-a\r", "b-to-b\r", "c-to-c\r\n", "d-to-d"), lines);
assertLinesMatch(Arrays.asList("he\r", "a-to-a\r", "b-to-b\r", "c-to-c\r\n", "d-to-d"), lines);
}
@Test
public void testFirstCharMatchOnly() throws IOException {
final List<String> lines = getLines("\nThe quick brown fox jumped over the lazy dog.");
assertEquals(Arrays.asList("\n", "The quick brown fox jumped over the lazy dog."), lines);
assertLinesMatch(Arrays.asList("\n", "The quick brown fox jumped over the lazy dog."), lines);
}
@Test

View File

@ -24,11 +24,11 @@ import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;

View File

@ -18,7 +18,7 @@ package org.apache.nifi.util;
import org.junit.jupiter.api.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class StringSelectorTest {
@Test

View File

@ -20,8 +20,8 @@ import org.junit.jupiter.api.Test;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
public class TestTimedBuffer {
@ -90,7 +90,7 @@ public class TestTimedBuffer {
return oldValue;
}
return new TimestampedLong(oldValue.getValue().longValue() + toAdd.getValue().longValue());
return new TimestampedLong(oldValue.getValue() + toAdd.getValue());
}
@Override

View File

@ -22,6 +22,8 @@ import javax.servlet.FilterConfig
import javax.servlet.ServletContext
import javax.servlet.http.HttpServletRequest
import static org.junit.jupiter.api.Assertions.assertEquals
class SanitizeContextPathFilterTest {
private static String getValue(String parameterName, Map<String, String> params = [:]) {
@ -45,7 +47,7 @@ class SanitizeContextPathFilterTest {
scpf.init(mockFilterConfig)
// Assert
assert scpf.getAllowedContextPaths() == EXPECTED_ALLOWED_CONTEXT_PATHS
assertEquals(EXPECTED_ALLOWED_CONTEXT_PATHS, scpf.getAllowedContextPaths())
}
@Test
@ -64,7 +66,7 @@ class SanitizeContextPathFilterTest {
scpf.init(mockFilterConfig)
// Assert
assert scpf.getAllowedContextPaths() == EXPECTED_ALLOWED_CONTEXT_PATHS
assertEquals(EXPECTED_ALLOWED_CONTEXT_PATHS, scpf.getAllowedContextPaths())
}
@Test
@ -113,6 +115,6 @@ class SanitizeContextPathFilterTest {
scpf.injectContextPathAttribute(mockRequest)
// Assert
assert requestAttributes["contextPath"] == EXPECTED_CONTEXT_PATH
assertEquals(EXPECTED_CONTEXT_PATH, requestAttributes["contextPath"])
}
}