NIFI-11035 Replaced remaining JUnit 4 assertions in nifi-commons with JUnit 5

- Replaced Groovy asserts with JUnit 5 assertions and Groovy shouldFail method Junit 5 with assertThrow method

This closes #6880

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
dan-s1 2023-01-23 22:28:49 +00:00 committed by exceptionfactory
parent 84f48b5e8c
commit 53371844a4
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
44 changed files with 924 additions and 1102 deletions

View File

@ -25,7 +25,10 @@ import org.junit.jupiter.api.Test
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
class QueryGroovyTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertNotEquals
class QueryGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(QueryGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(QueryGroovyTest.class)
@BeforeAll @BeforeAll
@ -79,11 +82,11 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceRepeatingResult.value}") logger.info("Replace repeating result: ${replaceRepeatingResult.value}")
// Assert // Assert
assert replaceSingleResult.value == EXPECTED_SINGLE_RESULT assertEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult.value)
assert replaceSingleResult.resultType == AttributeExpression.ResultType.STRING assertEquals(AttributeExpression.ResultType.STRING, replaceSingleResult.resultType)
assert replaceRepeatingResult.value == EXPECTED_REPEATING_RESULT assertEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult.value)
assert replaceRepeatingResult.resultType == AttributeExpression.ResultType.STRING assertEquals(AttributeExpression.ResultType.STRING, replaceRepeatingResult.resultType)
} }
@Test @Test
@ -119,11 +122,11 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceRepeatingResult.value}") logger.info("Replace repeating result: ${replaceRepeatingResult.value}")
// Assert // Assert
assert replaceSingleResult.value == EXPECTED_SINGLE_RESULT assertEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult.value)
assert replaceSingleResult.resultType == AttributeExpression.ResultType.STRING assertEquals(AttributeExpression.ResultType.STRING, replaceSingleResult.resultType)
assert replaceRepeatingResult.value == EXPECTED_REPEATING_RESULT assertEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult.value)
assert replaceRepeatingResult.resultType == AttributeExpression.ResultType.STRING assertEquals(AttributeExpression.ResultType.STRING, replaceRepeatingResult.resultType)
} }
@Test @Test
@ -159,11 +162,11 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceRepeatingResult.value}") logger.info("Replace repeating result: ${replaceRepeatingResult.value}")
// Assert // Assert
assert replaceSingleResult.value == EXPECTED_SINGLE_RESULT assertEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult.value)
assert replaceSingleResult.resultType == AttributeExpression.ResultType.STRING assertEquals(AttributeExpression.ResultType.STRING, replaceSingleResult.resultType)
assert replaceRepeatingResult.value == EXPECTED_REPEATING_RESULT assertEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult.value)
assert replaceRepeatingResult.resultType == AttributeExpression.ResultType.STRING assertEquals(AttributeExpression.ResultType.STRING, replaceRepeatingResult.resultType)
} }
@Test @Test
@ -200,10 +203,10 @@ class QueryGroovyTest extends GroovyTestCase {
logger.info("Replace repeating result: ${replaceFirstRepeatingResult}") logger.info("Replace repeating result: ${replaceFirstRepeatingResult}")
// Assert // Assert
assert replaceSingleResult != EXPECTED_SINGLE_RESULT assertNotEquals(EXPECTED_SINGLE_RESULT, replaceSingleResult)
assert replaceRepeatingResult != EXPECTED_REPEATING_RESULT assertNotEquals(EXPECTED_REPEATING_RESULT, replaceRepeatingResult)
assert replaceFirstSingleResult == EXPECTED_SINGLE_RESULT assertEquals(EXPECTED_SINGLE_RESULT, replaceFirstSingleResult)
assert replaceFirstRepeatingResult == EXPECTED_REPEATING_RESULT assertEquals(EXPECTED_REPEATING_RESULT, replaceFirstRepeatingResult)
} }
} }

View File

@ -23,9 +23,9 @@ import org.junit.jupiter.api.Test;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestValueLookup { public class TestValueLookup {

View File

@ -21,12 +21,12 @@ import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.nio.charset.StandardCharsets;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestPackageUnpackageV3 { public class TestPackageUnpackageV3 {
@ -35,7 +35,7 @@ public class TestPackageUnpackageV3 {
final FlowFilePackager packager = new FlowFilePackagerV3(); final FlowFilePackager packager = new FlowFilePackagerV3();
final FlowFileUnpackager unpackager = new FlowFileUnpackagerV3(); final FlowFileUnpackager unpackager = new FlowFileUnpackagerV3();
final byte[] data = "Hello, World!".getBytes("UTF-8"); final byte[] data = "Hello, World!".getBytes(StandardCharsets.UTF_8);
final Map<String, String> map = new HashMap<>(); final Map<String, String> map = new HashMap<>();
map.put("abc", "cba"); map.put("abc", "cba");
@ -50,7 +50,7 @@ public class TestPackageUnpackageV3 {
final byte[] decoded = decodedOut.toByteArray(); final byte[] decoded = decodedOut.toByteArray();
assertEquals(map, unpackagedAttributes); assertEquals(map, unpackagedAttributes);
assertTrue(Arrays.equals(data, decoded)); assertArrayEquals(data, decoded);
} }
} }

View File

@ -28,15 +28,14 @@ import org.apache.nifi.hl7.model.HL7Message;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@SuppressWarnings("resource") @SuppressWarnings("resource")
public class TestHL7Query { public class TestHL7Query {
@ -60,7 +59,7 @@ public class TestHL7Query {
private HL7Message hypoglycemia; private HL7Message hypoglycemia;
@BeforeEach @BeforeEach
public void init() throws IOException, HL7Exception { public void init() throws HL7Exception {
this.hyperglycemia = createMessage(HYPERGLYCEMIA); this.hyperglycemia = createMessage(HYPERGLYCEMIA);
this.hypoglycemia = createMessage(HYPOGLYCEMIA); this.hypoglycemia = createMessage(HYPOGLYCEMIA);
} }
@ -116,7 +115,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testSelectMessage() throws HL7Exception, IOException { public void testSelectMessage() {
final HL7Query query = HL7Query.compile("SELECT MESSAGE"); final HL7Query query = HL7Query.compile("SELECT MESSAGE");
final HL7Message msg = hypoglycemia; final HL7Message msg = hypoglycemia;
final QueryResult result = query.evaluate(msg); final QueryResult result = query.evaluate(msg);
@ -131,7 +130,7 @@ public class TestHL7Query {
@Test @Test
@SuppressWarnings({"unchecked", "rawtypes"}) @SuppressWarnings({"unchecked", "rawtypes"})
public void testSelectField() throws HL7Exception, IOException { public void testSelectField() {
final HL7Query query = HL7Query.compile("SELECT PID.5"); final HL7Query query = HL7Query.compile("SELECT PID.5");
final HL7Message msg = hypoglycemia; final HL7Message msg = hypoglycemia;
final QueryResult result = query.evaluate(msg); final QueryResult result = query.evaluate(msg);
@ -149,7 +148,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testSelectAbnormalTestResult() throws HL7Exception, IOException { public void testSelectAbnormalTestResult() {
final String query = "DECLARE result AS REQUIRED OBX SELECT result WHERE result.7 != 'N' AND result.1 = 1"; final String query = "DECLARE result AS REQUIRED OBX SELECT result WHERE result.7 != 'N' AND result.1 = 1";
final HL7Query hl7Query = HL7Query.compile(query); final HL7Query hl7Query = HL7Query.compile(query);
@ -158,7 +157,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testFieldEqualsString() throws HL7Exception, IOException { public void testFieldEqualsString() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.7 = 'L'"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.7 = 'L'");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -169,7 +168,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testLessThan() throws HL7Exception, IOException { public void testLessThan() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 < 600"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 < 600");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -180,7 +179,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testCompareTwoFields() throws HL7Exception, IOException { public void testCompareTwoFields() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 < result.6.2"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 < result.6.2");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -191,7 +190,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testLessThanOrEqual() throws HL7Exception, IOException { public void testLessThanOrEqual() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 <= 59"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 <= 59");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -206,7 +205,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testGreaterThanOrEqual() throws HL7Exception, IOException { public void testGreaterThanOrEqual() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 >= 59"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 >= 59");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -221,7 +220,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testGreaterThan() throws HL7Exception, IOException { public void testGreaterThan() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 > 58"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.4 > 58");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -236,7 +235,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testDistinctValuesReturned() throws HL7Exception, IOException { public void testDistinctValuesReturned() {
HL7Query hl7Query = HL7Query.compile("DECLARE result1 AS REQUIRED OBX, result2 AS REQUIRED OBX SELECT MESSAGE WHERE result1.7 = 'L' OR result2.7 != 'H'"); HL7Query hl7Query = HL7Query.compile("DECLARE result1 AS REQUIRED OBX, result2 AS REQUIRED OBX SELECT MESSAGE WHERE result1.7 = 'L' OR result2.7 != 'H'");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -244,7 +243,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testAndWithParents() throws HL7Exception, IOException { public void testAndWithParents() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.7 = 'L' AND result.3.1 = 'GLU'"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.7 = 'L' AND result.3.1 = 'GLU'");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -276,7 +275,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testIsNull() throws HL7Exception, IOException { public void testIsNull() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.999 IS NULL"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.999 IS NULL");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertTrue(result.isMatch()); assertTrue(result.isMatch());
@ -295,7 +294,7 @@ public class TestHL7Query {
} }
@Test @Test
public void testNotNull() throws HL7Exception, IOException { public void testNotNull() {
HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.999 NOT NULL"); HL7Query hl7Query = HL7Query.compile("DECLARE result AS REQUIRED OBX SELECT MESSAGE WHERE result.999 NOT NULL");
QueryResult result = hl7Query.evaluate(hypoglycemia); QueryResult result = hl7Query.evaluate(hypoglycemia);
assertFalse(result.isMatch()); assertFalse(result.isMatch());
@ -313,7 +312,7 @@ public class TestHL7Query {
assertTrue(result.isMatch()); assertTrue(result.isMatch());
} }
private HL7Message createMessage(final String msgText) throws HL7Exception, IOException { private HL7Message createMessage(final String msgText) throws HL7Exception {
final HapiContext hapiContext = new DefaultHapiContext(); final HapiContext hapiContext = new DefaultHapiContext();
hapiContext.setValidationContext(ValidationContextFactory.noValidation()); hapiContext.setValidationContext(ValidationContextFactory.noValidation());

View File

@ -25,8 +25,8 @@ import org.junit.jupiter.api.Test
import static groovy.json.JsonOutput.prettyPrint import static groovy.json.JsonOutput.prettyPrint
import static groovy.json.JsonOutput.toJson import static groovy.json.JsonOutput.toJson
import static org.junit.Assert.assertFalse import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.Assert.assertTrue import static org.junit.jupiter.api.Assertions.assertTrue
import static org.mockito.Mockito.mock import static org.mockito.Mockito.mock
class TestStandardValidators { class TestStandardValidators {

View File

@ -30,9 +30,9 @@ import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class KeyedCipherPropertyEncryptorTest { public class KeyedCipherPropertyEncryptorTest {
private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;

View File

@ -26,9 +26,9 @@ import org.junit.jupiter.api.Test;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class PasswordBasedCipherPropertyEncryptorTest { public class PasswordBasedCipherPropertyEncryptorTest {
private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;

View File

@ -23,9 +23,9 @@ import org.junit.jupiter.api.Test;
import java.util.Properties; import java.util.Properties;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class PropertyEncryptorFactoryTest { public class PropertyEncryptorFactoryTest {
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.MD5_256AES; private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.MD5_256AES;

View File

@ -21,8 +21,8 @@ import org.junit.jupiter.api.Test;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class StandardPropertySecretKeyProviderTest { public class StandardPropertySecretKeyProviderTest {
private static final String SEED = String.class.getName(); private static final String SEED = String.class.getName();
@ -42,7 +42,7 @@ public class StandardPropertySecretKeyProviderTest {
final SecretKey secretKey = provider.getSecretKey(propertyEncryptionMethod, SEED); final SecretKey secretKey = provider.getSecretKey(propertyEncryptionMethod, SEED);
final int secretKeyLength = secretKey.getEncoded().length; final int secretKeyLength = secretKey.getEncoded().length;
final String message = String.format("Method [%s] Key Length not matched", propertyEncryptionMethod); final String message = String.format("Method [%s] Key Length not matched", propertyEncryptionMethod);
assertEquals(message, propertyEncryptionMethod.getHashLength(), secretKeyLength); assertEquals(propertyEncryptionMethod.getHashLength(), secretKeyLength, message);
} }
} }

View File

@ -30,7 +30,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* Abstract base class for tests that walk FieldValue hierarchies. * Abstract base class for tests that walk FieldValue hierarchies.

View File

@ -24,8 +24,8 @@ import org.junit.jupiter.api.Test;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestFieldValueLogicalPathBuilder extends AbstractWalkerTest { public class TestFieldValueLogicalPathBuilder extends AbstractWalkerTest {

View File

@ -24,7 +24,7 @@ import org.junit.jupiter.api.Test;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestFieldValueWalker extends AbstractWalkerTest { public class TestFieldValueWalker extends AbstractWalkerTest {

View File

@ -29,7 +29,7 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestSimpleRecordSchema { public class TestSimpleRecordSchema {

View File

@ -52,10 +52,10 @@ import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@ -589,7 +589,7 @@ public class ResultSetRecordSetTest {
List<RecordField> fields = new ArrayList<>(concreteRecords.size()); List<RecordField> fields = new ArrayList<>(concreteRecords.size());
int i = 1; int i = 1;
for (Record record : concreteRecords) { for (Record record : concreteRecords) {
fields.add(new RecordField("record" + String.valueOf(i), RecordFieldType.RECORD.getRecordDataType(record.getSchema()))); fields.add(new RecordField("record" + i, RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
++i; ++i;
} }
return fields; return fields;
@ -614,7 +614,7 @@ public class ResultSetRecordSetTest {
} }
private void thenAllDataTypesMatchInputFieldType(final List<RecordField> inputFields, final RecordSchema resultSchema) { private void thenAllDataTypesMatchInputFieldType(final List<RecordField> inputFields, final RecordSchema resultSchema) {
assertEquals("The number of input fields does not match the number of fields in the result schema.", inputFields.size(), resultSchema.getFieldCount()); assertEquals(inputFields.size(), resultSchema.getFieldCount(), "The number of input fields does not match the number of fields in the result schema.");
for (int i = 0; i < inputFields.size(); ++i) { for (int i = 0; i < inputFields.size(); ++i) {
assertEquals(inputFields.get(i).getDataType(), resultSchema.getField(i).getDataType()); assertEquals(inputFields.get(i).getDataType(), resultSchema.getField(i).getDataType());
} }
@ -640,7 +640,7 @@ public class ResultSetRecordSetTest {
expectedDataType = RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(), decimalDataType.getScale()); expectedDataType = RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(), decimalDataType.getScale());
} }
} }
assertEquals("For column " + column.getIndex() + " the converted type is not matching", expectedDataType, actualDataType); assertEquals(expectedDataType, actualDataType, "For column " + column.getIndex() + " the converted type is not matching");
} }
} }
@ -648,9 +648,9 @@ public class ResultSetRecordSetTest {
for (RecordField recordField : actualSchema.getFields()) { for (RecordField recordField : actualSchema.getFields()) {
if (recordField.getDataType() instanceof ArrayDataType) { if (recordField.getDataType() instanceof ArrayDataType) {
ArrayDataType arrayType = (ArrayDataType) recordField.getDataType(); ArrayDataType arrayType = (ArrayDataType) recordField.getDataType();
assertEquals("Array element type for " + recordField.getFieldName() assertEquals(expectedTypes.get(recordField.getFieldName()), arrayType.getElementType(),
+ " is not of expected type " + expectedTypes.get(recordField.getFieldName()).toString(), "Array element type for " + recordField.getFieldName()
expectedTypes.get(recordField.getFieldName()), arrayType.getElementType()); + " is not of expected type " + expectedTypes.get(recordField.getFieldName()).toString());
} else { } else {
fail("RecordField " + recordField.getFieldName() + " is not instance of ArrayDataType"); fail("RecordField " + recordField.getFieldName() + " is not instance of ArrayDataType");
} }
@ -658,7 +658,7 @@ public class ResultSetRecordSetTest {
} }
private void thenAllDataTypesAreChoice(final List<RecordField> inputFields, final RecordSchema resultSchema) { private void thenAllDataTypesAreChoice(final List<RecordField> inputFields, final RecordSchema resultSchema) {
assertEquals("The number of input fields does not match the number of fields in the result schema.", inputFields.size(), resultSchema.getFieldCount()); assertEquals(inputFields.size(), resultSchema.getFieldCount(), "The number of input fields does not match the number of fields in the result schema.");
DataType expectedType = getBroadestChoiceDataType(); DataType expectedType = getBroadestChoiceDataType();
for (int i = 0; i < inputFields.size(); ++i) { for (int i = 0; i < inputFields.size(); ++i) {

View File

@ -16,14 +16,16 @@
*/ */
package org.apache.nifi.security.util package org.apache.nifi.security.util
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
class TlsConfigurationTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
class TlsConfigurationTest {
private static final Logger logger = LoggerFactory.getLogger(TlsConfigurationTest.class) private static final Logger logger = LoggerFactory.getLogger(TlsConfigurationTest.class)
@BeforeAll @BeforeAll
@ -33,17 +35,6 @@ class TlsConfigurationTest extends GroovyTestCase {
} }
} }
@BeforeEach
void setUp() {
super.setUp()
}
@AfterEach
void tearDown() {
}
@Test @Test
void testShouldParseJavaVersion() { void testShouldParseJavaVersion() {
// Arrange // Arrange
@ -57,7 +48,8 @@ class TlsConfigurationTest extends GroovyTestCase {
logger.info("Major versions: ${majorVersions}") logger.info("Major versions: ${majorVersions}")
// Assert // Assert
assert majorVersions == (5..12) assertTrue(majorVersions.stream()
.allMatch(num -> num >= 5 && num <= 12))
} }
@Test @Test
@ -72,9 +64,9 @@ class TlsConfigurationTest extends GroovyTestCase {
// Assert // Assert
if (javaMajorVersion < 11) { if (javaMajorVersion < 11) {
assert tlsVersions == ["TLSv1.2"] as String[] assertArrayEquals(new String[]{"TLSv1.2"}, tlsVersions)
} else { } else {
assert tlsVersions == ["TLSv1.3", "TLSv1.2"] as String[] assertArrayEquals(new String[]{"TLSv1.3", "TLSv1.2"}, tlsVersions)
} }
} }
@ -90,9 +82,9 @@ class TlsConfigurationTest extends GroovyTestCase {
// Assert // Assert
if (javaMajorVersion < 11) { if (javaMajorVersion < 11) {
assert tlsVersion == "TLSv1.2" assertEquals("TLSv1.2", tlsVersion)
} else { } else {
assert tlsVersion == "TLSv1.3" assertEquals("TLSv1.3", tlsVersion)
} }
} }
} }

View File

@ -20,7 +20,11 @@ import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x500.style.IETFUtils import org.bouncycastle.asn1.x500.style.IETFUtils
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.Extensions
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.operator.OperatorCreationException import org.bouncycastle.operator.OperatorCreationException
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest
@ -35,16 +39,31 @@ import javax.net.ssl.SSLException
import javax.net.ssl.SSLPeerUnverifiedException import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSession import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSocket import javax.net.ssl.SSLSocket
import java.security.* import java.security.InvalidKeyException
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.NoSuchAlgorithmException
import java.security.NoSuchProviderException
import java.security.SignatureException
import java.security.cert.Certificate import java.security.cert.Certificate
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.* import java.util.concurrent.Callable
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException
import java.util.concurrent.Executors
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import static org.junit.Assert.assertTrue import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class CertificateUtilsTest extends GroovyTestCase { class CertificateUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(CertificateUtilsTest.class) private static final Logger logger = LoggerFactory.getLogger(CertificateUtilsTest.class)
private static final int KEY_SIZE = 2048 private static final int KEY_SIZE = 2048
@ -84,10 +103,15 @@ class CertificateUtilsTest extends GroovyTestCase {
* *
* @param dn the DN * @param dn the DN
* @return the certificate * @return the certificate
* @throws IOException* @throws NoSuchAlgorithmException* @throws java.security.cert.CertificateException* @throws java.security.NoSuchProviderException* @throws java.security.SignatureException* @throws java.security.InvalidKeyException* @throws OperatorCreationException * @throws IOException* @throws NoSuchAlgorithmException
* @throws java.security.cert.CertificateException*
* @throws java.security.NoSuchProviderException
* @throws java.security.SignatureException
* @throws OperatorCreationException
*/ */
private private
static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException { static X509Certificate generateCertificate(String dn) throws IOException, NoSuchAlgorithmException, CertificateException,
NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair() KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateSelfSignedX509Certificate(keyPair, dn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR) return CertificateUtils.generateSelfSignedX509Certificate(keyPair, dn, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
} }
@ -99,10 +123,16 @@ class CertificateUtilsTest extends GroovyTestCase {
* @param issuerDn the issuer DN * @param issuerDn the issuer DN
* @param issuerKey the issuer private key * @param issuerKey the issuer private key
* @return the certificate * @return the certificate
* @throws IOException* @throws NoSuchAlgorithmException* @throws CertificateException* @throws NoSuchProviderException* @throws SignatureException* @throws InvalidKeyException* @throws OperatorCreationException * @throws IOException
* @throws NoSuchAlgorithmException
* @throws CertificateException
* @throws NoSuchProviderException*
* @throws SignatureException* @throws InvalidKeyException
* @throws OperatorCreationException
*/ */
private private
static X509Certificate generateIssuedCertificate(String dn, X509Certificate issuer, KeyPair issuerKey) throws IOException, NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException { static X509Certificate generateIssuedCertificate(String dn, X509Certificate issuer, KeyPair issuerKey) throws IOException,
NoSuchAlgorithmException, CertificateException, NoSuchProviderException, SignatureException, InvalidKeyException, OperatorCreationException {
KeyPair keyPair = generateKeyPair() KeyPair keyPair = generateKeyPair()
return CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKey, SIGNATURE_ALGORITHM, DAYS_IN_YEAR) return CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKey, SIGNATURE_ALGORITHM, DAYS_IN_YEAR)
} }
@ -143,8 +173,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Converted certificate: ${convertedCertificate.class.canonicalName} ${convertedCertificate.subjectDN.toString()} (${convertedCertificate.getSerialNumber()})") logger.info("Converted certificate: ${convertedCertificate.class.canonicalName} ${convertedCertificate.subjectDN.toString()} (${convertedCertificate.getSerialNumber()})")
// Assert // Assert
assert convertedCertificate instanceof X509Certificate assertEquals(EXPECTED_NEW_CERTIFICATE, convertedCertificate)
assert convertedCertificate == EXPECTED_NEW_CERTIFICATE
} }
@Test @Test
@ -163,9 +192,9 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Client auth (noneSocket): ${noneClientAuthStatus}") logger.info("Client auth (noneSocket): ${noneClientAuthStatus}")
// Assert // Assert
assert needClientAuthStatus == ClientAuth.REQUIRED assertEquals(ClientAuth.REQUIRED, needClientAuthStatus)
assert wantClientAuthStatus == ClientAuth.WANT assertEquals(ClientAuth.WANT, wantClientAuthStatus)
assert noneClientAuthStatus == ClientAuth.NONE assertEquals(ClientAuth.NONE, noneClientAuthStatus)
} }
@Test @Test
@ -210,9 +239,7 @@ class CertificateUtilsTest extends GroovyTestCase {
} }
// Assert // Assert
assert resolvedServerDNs.every { String serverDN -> resolvedServerDNs.stream().forEach(serverDN -> assertTrue(CertificateUtils.compareDNs(serverDN, EXPECTED_DN)))
CertificateUtils.compareDNs(serverDN, EXPECTED_DN)
}
} }
@Test @Test
@ -231,7 +258,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
assert !clientDN assertNull(clientDN)
} }
@Test @Test
@ -257,7 +284,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN) assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
} }
@Test @Test
@ -280,7 +307,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
assert CertificateUtils.compareDNs(clientDN, null) assertTrue(CertificateUtils.compareDNs(clientDN, null))
} }
@ -307,7 +334,7 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Extracted client DN: ${clientDN}") logger.info("Extracted client DN: ${clientDN}")
// Assert // Assert
assert CertificateUtils.compareDNs(clientDN, EXPECTED_DN) assertTrue(CertificateUtils.compareDNs(clientDN, EXPECTED_DN))
} }
@Test @Test
@ -326,13 +353,11 @@ class CertificateUtilsTest extends GroovyTestCase {
] as SSLSocket ] as SSLSocket
// Act // Act
def msg = shouldFail(CertificateException) { CertificateException ce = assertThrows(CertificateException.class,
String clientDN = CertificateUtils.extractPeerDNFromSSLSocket(mockSocket) () -> CertificateUtils.extractPeerDNFromSSLSocket(mockSocket))
logger.info("Extracted client DN: ${clientDN}")
}
// Assert // Assert
assert msg =~ "peer not authenticated" assertTrue(ce.getMessage().contains("peer not authenticated"))
} }
@Test @Test
@ -374,14 +399,13 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.matches("DN 1, empty: ${dn1MatchesEmpty}") logger.matches("DN 1, empty: ${dn1MatchesEmpty}")
// Assert // Assert
assert dn1MatchesSelf assertTrue(dn1MatchesReversed)
assert dn1MatchesReversed assertTrue(emptyMatchesEmpty)
assert emptyMatchesEmpty assertTrue(nullMatchesNull)
assert nullMatchesNull
assert !dn1MatchesDn2 assertFalse(dn1MatchesDn2)
assert !dn1MatchesDn2Reversed assertFalse(dn1MatchesDn2Reversed)
assert !dn1MatchesEmpty assertFalse(dn1MatchesEmpty)
} }
@Test @Test
@ -545,10 +569,9 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Issued certificate with subject: ${certificate.getSubjectDN().name} and SAN: ${certificate.getSubjectAlternativeNames().join(",")}") logger.info("Issued certificate with subject: ${certificate.getSubjectDN().name} and SAN: ${certificate.getSubjectAlternativeNames().join(",")}")
// Assert // Assert
assert certificate instanceof X509Certificate assertEquals(SUBJECT_DN, certificate.getSubjectDN().name)
assert certificate.getSubjectDN().name == SUBJECT_DN assertEquals(SANS.size(), certificate.getSubjectAlternativeNames().size())
assert certificate.getSubjectAlternativeNames().size() == SANS.size() assertTrue(certificate.getSubjectAlternativeNames()*.last().containsAll(SANS))
assert certificate.getSubjectAlternativeNames()*.last().containsAll(SANS)
} }
@Test @Test
@ -575,9 +598,9 @@ class CertificateUtilsTest extends GroovyTestCase {
logger.info("Unrelated results: ${unrelatedResults}") logger.info("Unrelated results: ${unrelatedResults}")
// Assert // Assert
assert directResults.every() assertTrue(directResults.every())
assert causedResults.every() assertTrue(causedResults.every())
assert !unrelatedResults.any() assertFalse(unrelatedResults.any())
} }
@Test @Test

View File

@ -30,7 +30,10 @@ import javax.crypto.spec.SecretKeySpec
import java.security.SecureRandom import java.security.SecureRandom
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assumptions.assumeTrue import static org.junit.jupiter.api.Assumptions.assumeTrue
class AESKeyedCipherProviderGroovyTest { class AESKeyedCipherProviderGroovyTest {
@ -80,7 +83,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -107,7 +110,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -148,7 +151,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
} }
@ -161,12 +164,11 @@ class AESKeyedCipherProviderGroovyTest {
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider.getCipher(encryptionMethod, null, true) () -> cipherProvider.getCipher(encryptionMethod, null, true))
}
// Assert // Assert
assert msg =~ "The key must be specified" assertTrue(iae.message.contains("The key must be specified"))
} }
@Test @Test
@ -175,17 +177,16 @@ class AESKeyedCipherProviderGroovyTest {
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider()
SecretKey localKey = new SecretKeySpec(Hex.decodeHex("0123456789ABCDEF" as char[]), "AES") SecretKey localKey = new SecretKeySpec(Hex.decodeHex("0123456789ABCDEF" as char[]), "AES")
assert ![128, 192, 256].contains(localKey.encoded.length) assertFalse([128, 192, 256].contains(localKey.encoded.length))
final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC final EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider.getCipher(encryptionMethod, localKey, true) () -> cipherProvider.getCipher(encryptionMethod, localKey, true))
}
// Assert // Assert
assert msg =~ "The key must be of length \\[128, 192, 256\\]" assertTrue(iae.message.contains("The key must be of length [128, 192, 256]"))
} }
@Test @Test
@ -194,12 +195,11 @@ class AESKeyedCipherProviderGroovyTest {
KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider() KeyedCipherProvider cipherProvider = new AESKeyedCipherProvider()
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider.getCipher(null, key, true) () -> cipherProvider.getCipher(null, key, true))
}
// Assert // Assert
assert msg =~ "The encryption method must be specified" assertTrue(iae.message.contains("The encryption method must be specified"))
} }
@Test @Test
@ -210,12 +210,11 @@ class AESKeyedCipherProviderGroovyTest {
final EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES final EncryptionMethod encryptionMethod = EncryptionMethod.MD5_128AES
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider.getCipher(encryptionMethod, key, true) () -> cipherProvider.getCipher(encryptionMethod, key, true))
}
// Assert // Assert
assert msg =~ " requires a PBECipherProvider" assertTrue(iae.message.contains("requires a PBECipherProvider"))
} }
@Test @Test
@ -244,7 +243,7 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
@Test @Test
@ -264,12 +263,11 @@ class AESKeyedCipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}") logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(em, key, false) () -> cipherProvider.getCipher(em, key, false))
}
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.message.contains("Cannot decrypt without a valid IV"))
} }
} }
@ -290,13 +288,12 @@ class AESKeyedCipherProviderGroovyTest {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, true)
// Decrypt should fail // Decrypt should fail
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, false) () -> cipherProvider.getCipher(encryptionMethod, key, badIV, false))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }
@ -317,12 +314,11 @@ class AESKeyedCipherProviderGroovyTest {
logger.info("IV after encrypt: ${Hex.encodeHexString(cipher.getIV())}") logger.info("IV after encrypt: ${Hex.encodeHexString(cipher.getIV())}")
// Decrypt should fail // Decrypt should fail
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(encryptionMethod, key, badIV, false) () -> cipherProvider.getCipher(encryptionMethod, key, badIV, false))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }

View File

@ -20,7 +20,6 @@ import org.apache.commons.codec.binary.Base64
import org.apache.commons.codec.binary.Hex import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.EncryptionMethod import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.Assumptions
import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
@ -33,9 +32,14 @@ import javax.crypto.spec.SecretKeySpec
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class Argon2CipherProviderGroovyTest extends GroovyTestCase { class Argon2CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(Argon2CipherProviderGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(Argon2CipherProviderGroovyTest.class)
private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess" private static final String PLAINTEXT = "ExactBlockSizeRequiredForProcess"
@ -92,7 +96,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -128,9 +132,9 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes) byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}") logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec) rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assert rubyCipher.doFinal(rubyCipherBytes) == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
logger.sanity("Decrypted generated cipher text successfully") logger.sanity("Decrypted generated cipher text successfully")
assert rubyCipher.doFinal(cipherBytes) == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text successfully") logger.sanity("Decrypted external cipher text successfully")
// $argon2id$v=19$m=memory,t=iterations,p=parallelism$saltB64$hashB64 // $argon2id$v=19$m=memory,t=iterations,p=parallelism$saltB64$hashB64
@ -149,7 +153,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
def saltB64 = hashComponents[4] def saltB64 = hashComponents[4]
byte[] salt = Base64.decodeBase64(saltB64) byte[] salt = Base64.decodeBase64(saltB64)
logger.info("Salt: ${Hex.encodeHexString(salt)}") logger.info("Salt: ${Hex.encodeHexString(salt)}")
assert salt == SALT assertArrayEquals(SALT, salt)
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}") logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}")
@ -161,7 +165,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
@Test @Test
@ -181,12 +185,11 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true)
// Decrypt should fail // Decrypt should fail
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false))
}
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }
@ -214,7 +217,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -244,7 +247,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -263,13 +266,12 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
INVALID_SALTS.each { String salt -> INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}") logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
} logger.expected(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ LENGTH_MESSAGE assertTrue(iae.getMessage().contains(LENGTH_MESSAGE))
} }
} }
@ -290,7 +292,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
// Assert // Assert
assert cipher assertNotNull(cipher)
} }
} }
@ -303,14 +305,12 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act // Act
def msg = IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
shouldFail(IllegalArgumentException) { () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true) logger.expected(iae.getMessage())
}
logger.expected(msg)
// Assert // Assert
assert msg =~ "The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()" assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()"))
} }
@Test @Test
@ -333,13 +333,15 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
// Assert // Assert
boolean isValidFormattedSalt = cipherProvider.isArgon2FormattedSalt(fullSalt) boolean isValidFormattedSalt = cipherProvider.isArgon2FormattedSalt(fullSalt)
logger.info("Salt is Argon2 format: ${isValidFormattedSalt}") logger.info("Salt is Argon2 format: ${isValidFormattedSalt}")
assert isValidFormattedSalt assertTrue(isValidFormattedSalt)
boolean fullSaltIsValidLength = FULL_SALT_LENGTH_RANGE.contains(saltBytes.length) boolean fullSaltIsValidLength = FULL_SALT_LENGTH_RANGE.contains(saltBytes.length)
logger.info("Salt length (${fullSalt.length()}) in valid range (${FULL_SALT_LENGTH_RANGE})") logger.info("Salt length (${fullSalt.length()}) in valid range (${FULL_SALT_LENGTH_RANGE})")
assert fullSaltIsValidLength assertTrue(fullSaltIsValidLength)
assert rawSaltBytes != [(0x00 as byte) * 16] byte [] notExpected = new byte[16]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, rawSaltBytes))
} }
@Test @Test
@ -360,13 +362,12 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}") logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
} logger.expected(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }
@ -397,7 +398,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -415,14 +416,13 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
INVALID_KEY_LENGTHS.each { int keyLength -> INVALID_KEY_LENGTHS.each { int keyLength ->
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption // Initialize a cipher for
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
} logger.expected(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "${keyLength} is not a valid key length for AES" assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
} }
} }
@ -435,12 +435,11 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true))
}
// Assert // Assert
assert msg =~ "Encryption with an empty password is not supported" assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
} }
@Test @Test
@ -463,10 +462,10 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params) cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params)
// Assert // Assert
assert rawSalt == EXPECTED_RAW_SALT assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
assert params[0] == EXPECTED_MEMORY assertEquals(EXPECTED_MEMORY, params[0])
assert params[1] == EXPECTED_PARALLELISM assertEquals(EXPECTED_PARALLELISM, params[1])
assert params[2] == EXPECTED_ITERATIONS assertEquals(EXPECTED_ITERATIONS, params[2])
} }
@Test @Test
@ -485,7 +484,7 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("Argon2 formatted salt: ${isValid}") logger.info("Argon2 formatted salt: ${isValid}")
// Assert // Assert
assert !isValid assertFalse(isValid)
} }
@Test @Test
@ -505,6 +504,6 @@ class Argon2CipherProviderGroovyTest extends GroovyTestCase {
logger.info("rawSalt: ${Hex.encodeHexString(rawSalt)}") logger.info("rawSalt: ${Hex.encodeHexString(rawSalt)}")
// Assert // Assert
assert rawSalt == EXPECTED_RAW_SALT assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
} }
} }

View File

@ -27,7 +27,12 @@ import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.security.Security import java.security.Security
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class Argon2SecureHasherTest { class Argon2SecureHasherTest {
private static final Logger logger = LoggerFactory.getLogger(Argon2SecureHasherTest.class) private static final Logger logger = LoggerFactory.getLogger(Argon2SecureHasherTest.class)
@ -68,7 +73,7 @@ class Argon2SecureHasherTest {
} }
// Assert // Assert
assert results.every { it == EXPECTED_HASH_HEX } results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -98,8 +103,8 @@ class Argon2SecureHasherTest {
} }
// Assert // Assert
assert results.unique().size() == results.size() assertTrue(results.unique().size() == results.size())
assert results.every { it != EXPECTED_HASH_HEX } results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -143,20 +148,20 @@ class Argon2SecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input) String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert // Assert
assert staticSaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assert arbitrarySaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assert differentSaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assert differentSaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64 assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64 assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
} }
@Test @Test
@ -198,7 +203,7 @@ class Argon2SecureHasherTest {
logger.info("Generated hash: ${hashHex}") logger.info("Generated hash: ${hashHex}")
// Assert // Assert
assert hashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, hashHex)
} }
@Test @Test
@ -215,7 +220,7 @@ class Argon2SecureHasherTest {
logger.info("Generated hash: ${hashB64}") logger.info("Generated hash: ${hashB64}")
// Assert // Assert
assert hashB64 == EXPECTED_HASH_B64 assertEquals(EXPECTED_HASH_B64, hashB64)
} }
@Test @Test
@ -243,8 +248,8 @@ class Argon2SecureHasherTest {
} }
// Assert // Assert
assert hexResults.every { it == EXPECTED_HASH_HEX } hexResults.forEach(hexResult -> assertEquals(EXPECTED_HASH_HEX, hexResult))
assert b64Results.every { it == EXPECTED_HASH_B64 } b64Results.forEach(b64Result -> assertEquals(EXPECTED_HASH_B64, b64Result))
} }
/** /**
@ -282,8 +287,8 @@ class Argon2SecureHasherTest {
// Assert // Assert
final long MIN_DURATION_NANOS = 500_000_000 // 500 ms final long MIN_DURATION_NANOS = 500_000_000 // 500 ms
assert resultDurations.min() > MIN_DURATION_NANOS assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
} }
@Test @Test
@ -295,7 +300,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isHashLengthValid(hashLength) boolean valid = Argon2SecureHasher.isHashLengthValid(hashLength)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -312,7 +317,7 @@ class Argon2SecureHasherTest {
// Assert // Assert
results.each { hashLength, isHashLengthValid -> results.each { hashLength, isHashLengthValid ->
logger.info("For hashLength value ${hashLength}, hashLength is ${isHashLengthValid ? "valid" : "invalid"}") logger.info("For hashLength value ${hashLength}, hashLength is ${isHashLengthValid ? "valid" : "invalid"}")
assert !isHashLengthValid assertFalse(isHashLengthValid)
} }
} }
@ -325,7 +330,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isMemorySizeValid(memory) boolean valid = Argon2SecureHasher.isMemorySizeValid(memory)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -342,7 +347,7 @@ class Argon2SecureHasherTest {
// Assert // Assert
results.each { memory, isMemorySizeValid -> results.each { memory, isMemorySizeValid ->
logger.info("For memory size ${memory}, memory is ${isMemorySizeValid ? "valid" : "invalid"}") logger.info("For memory size ${memory}, memory is ${isMemorySizeValid ? "valid" : "invalid"}")
assert !isMemorySizeValid assertFalse(isMemorySizeValid)
} }
} }
@ -355,7 +360,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isParallelismValid(parallelism) boolean valid = Argon2SecureHasher.isParallelismValid(parallelism)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -372,7 +377,7 @@ class Argon2SecureHasherTest {
// Assert // Assert
results.each { parallelism, isParallelismValid -> results.each { parallelism, isParallelismValid ->
logger.info("For parallelization factor ${parallelism}, parallelism is ${isParallelismValid ? "valid" : "invalid"}") logger.info("For parallelization factor ${parallelism}, parallelism is ${isParallelismValid ? "valid" : "invalid"}")
assert !isParallelismValid assertFalse(isParallelismValid)
} }
} }
@ -385,7 +390,7 @@ class Argon2SecureHasherTest {
boolean valid = Argon2SecureHasher.isIterationsValid(iterations) boolean valid = Argon2SecureHasher.isIterationsValid(iterations)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -402,7 +407,7 @@ class Argon2SecureHasherTest {
// Assert // Assert
results.each { iterations, isIterationsValid -> results.each { iterations, isIterationsValid ->
logger.info("For iteration counts ${iterations}, iteration is ${isIterationsValid ? "valid" : "invalid"}") logger.info("For iteration counts ${iterations}, iteration is ${isIterationsValid ? "valid" : "invalid"}")
assert !isIterationsValid assertFalse(isIterationsValid)
} }
} }
@ -411,17 +416,11 @@ class Argon2SecureHasherTest {
// Arrange // Arrange
def saltLengths = [0, 64] def saltLengths = [0, 64]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher()
def isValid = new Argon2SecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> {
[saltLength, isValid] assertTrue(argon2SecureHasher.isSaltLengthValid(saltLength))
} })
// Assert
results.each { saltLength, isSaltLengthValid ->
logger.info("For salt length ${saltLength}, saltLength is ${isSaltLengthValid ? "valid" : "invalid"}")
assert isSaltLengthValid
}
} }
@Test @Test
@ -429,17 +428,9 @@ class Argon2SecureHasherTest {
// Arrange // Arrange
def saltLengths = [-16, 4] def saltLengths = [-16, 4]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher()
def isValid = new Argon2SecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> assertFalse(argon2SecureHasher.isSaltLengthValid(saltLength)))
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
logger.info("For salt length ${saltLength}, saltLength is ${isSaltLengthValid ? "valid" : "invalid"}")
assert !isSaltLengthValid
}
} }
@Test @Test
@ -460,7 +451,7 @@ class Argon2SecureHasherTest {
} }
// Assert // Assert
assert results[16][0..15] != results[32][0..15] assertFalse(Arrays.equals(Arrays.copyOf(results[16], 16), Arrays.copyOf(results[32], 16)))
// Demonstrates that internal hash truncation is not supported // Demonstrates that internal hash truncation is not supported
// assert results.every { int k, byte[] v -> v[0..15] as byte[] == EXPECTED_HASH} // assert results.every { int k, byte[] v -> v[0..15] as byte[] == EXPECTED_HASH}
} }

View File

@ -34,9 +34,10 @@ import java.nio.charset.StandardCharsets
import java.security.MessageDigest import java.security.MessageDigest
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.Assert.assertTrue import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assumptions.assumeTrue import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assertions.assertThrows
class BcryptCipherProviderGroovyTest { class BcryptCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(BcryptCipherProviderGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(BcryptCipherProviderGroovyTest.class)
@ -88,7 +89,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -118,7 +119,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -150,7 +151,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -171,8 +172,8 @@ class BcryptCipherProviderGroovyTest {
logger.info("Generated ${secureHasherCalculatedHash}") logger.info("Generated ${secureHasherCalculatedHash}")
// Assert // Assert
assert secureHasherCalculatedHash == EXPECTED_HASH assertEquals(EXPECTED_HASH, secureHasherCalculatedHash)
assert secureHasherCalculatedHash == EXPECTED_HASH assertEquals(EXPECTED_HASH, secureHasherCalculatedHash)
} }
@Test @Test
@ -217,8 +218,8 @@ class BcryptCipherProviderGroovyTest {
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes) byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.info("Expected cipher text: ${Hex.encodeHexString(rubyCipherBytes)}") logger.info("Expected cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec) rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assert rubyCipher.doFinal(rubyCipherBytes) == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
assert rubyCipher.doFinal(cipherBytes) == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text and generated cipher text successfully") logger.sanity("Decrypted external cipher text and generated cipher text successfully")
// Sanity for hash generation // Sanity for hash generation
@ -226,7 +227,7 @@ class BcryptCipherProviderGroovyTest {
logger.sanity("Salt from external: ${FULL_SALT}") logger.sanity("Salt from external: ${FULL_SALT}")
String generatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, BcryptCipherProvider.extractRawSalt(FULL_SALT), PASSWORD.bytes)) String generatedHash = new String(BCrypt.withDefaults().hash(WORK_FACTOR, BcryptCipherProvider.extractRawSalt(FULL_SALT), PASSWORD.bytes))
logger.sanity("Generated hash: ${generatedHash}") logger.sanity("Generated hash: ${generatedHash}")
assert generatedHash == FULL_HASH assertEquals(FULL_HASH, generatedHash)
// Act // Act
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, FULL_SALT.bytes, IV, DEFAULT_KEY_LENGTH, false) Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, FULL_SALT.bytes, IV, DEFAULT_KEY_LENGTH, false)
@ -235,7 +236,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
private static byte[] customB64Decode(String input) { private static byte[] customB64Decode(String input) {
@ -294,7 +295,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
@Test @Test
@ -313,13 +314,12 @@ class BcryptCipherProviderGroovyTest {
INVALID_SALTS.each { String salt -> INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}") logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "The salt must be of the format \\\$2a\\\$10\\\$gUVbkVzp79H8YaCOsCVZNu\\. To generate a salt, use BcryptCipherProvider#generateSalt" assertTrue(iae.getMessage().contains("The salt must be of the format \$2a\$10\$gUVbkVzp79H8YaCOsCVZNu. To generate a salt, use BcryptCipherProvider#generateSalt"))
} }
} }
@ -346,13 +346,12 @@ class BcryptCipherProviderGroovyTest {
// Two different errors -- one explaining the no-salt method is not supported, and the other for an empty byte[] passed // Two different errors -- one explaining the no-salt method is not supported, and the other for an empty byte[] passed
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "The salt must be of the format .* To generate a salt, use BcryptCipherProvider#generateSalt" assertTrue((iae.getMessage() =~ "The salt must be of the format .* To generate a salt, use BcryptCipherProvider#generateSalt").find())
} }
@Test @Test
@ -375,12 +374,11 @@ class BcryptCipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}") logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
}
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }
@ -414,7 +412,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -436,12 +434,11 @@ class BcryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
}
// Assert // Assert
assert msg =~ "${keyLength} is not a valid key length for AES" assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
} }
} }
@ -457,8 +454,9 @@ class BcryptCipherProviderGroovyTest {
logger.info("Salt: ${salt}") logger.info("Salt: ${salt}")
// Assert // Assert
assert salt =~ /^\$2[axy]\$\d{2}\$/
assert salt.contains("\$${workFactor}\$") assertTrue((salt =~ /^\$2[axy]\$\d{2}\$/).find())
assertTrue(salt.contains("\$" + workFactor + "\$"))
} }
/** /**
@ -518,8 +516,8 @@ class BcryptCipherProviderGroovyTest {
logger.info("Verified: ${verificationRecovered}") logger.info("Verified: ${verificationRecovered}")
// Assert // Assert
assert PLAINTEXT == recovered assertEquals(PLAINTEXT, recovered)
assert PLAINTEXT == verificationRecovered assertEquals(PLAINTEXT, verificationRecovered)
} }
} }
@ -549,7 +547,7 @@ class BcryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT == recovered assertEquals(PLAINTEXT, recovered)
} }
} }
@ -566,27 +564,22 @@ class BcryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${em.getAlgorithm()}") logger.info("Using algorithm: ${em.getAlgorithm()}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
def encryptMsg = shouldFail(IllegalArgumentException) { IllegalArgumentException encryptIae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, true))
byte[] iv = cipher.getIV()
logger.info("IV: ${Hex.encodeHexString(iv)}") logger.warn("Encrypt error: " + encryptIae.getMessage())
}
logger.expected("Encrypt error: ${encryptMsg}")
byte[] cipherBytes = PLAINTEXT.reverse().getBytes(StandardCharsets.UTF_8) byte[] cipherBytes = PLAINTEXT.reverse().getBytes(StandardCharsets.UTF_8)
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}") logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def decryptMsg = shouldFail(IllegalArgumentException) { IllegalArgumentException decryptIae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(em, PASSWORD, SALT, [0x00] * 16 as byte[], DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(em, PASSWORD, SALT, [0x00] * 16 as byte[], DEFAULT_KEY_LENGTH, false))
byte[] recoveredBytes = cipher.doFinal(cipherBytes)
String recovered = new String(recoveredBytes, "UTF-8") logger.warn("Decrypt error: " + decryptIae.getMessage())
logger.info("Recovered: ${recovered}")
}
logger.expected("Decrypt error: ${decryptMsg}")
// Assert // Assert
assert encryptMsg =~ "The salt must be of the format" assertTrue(encryptIae.getMessage().contains("The salt must be of the format"))
assert decryptMsg =~ "The salt must be of the format" assertTrue(decryptIae.getMessage().contains("The salt must be of the format"))
} }
@Disabled("This test can be run on a specific machine to evaluate if the default work factor is sufficient") @Disabled("This test can be run on a specific machine to evaluate if the default work factor is sufficient")

View File

@ -26,7 +26,12 @@ import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class BcryptSecureHasherTest { class BcryptSecureHasherTest {
private static final Logger logger = LoggerFactory.getLogger(BcryptSecureHasher) private static final Logger logger = LoggerFactory.getLogger(BcryptSecureHasher)
@ -62,7 +67,7 @@ class BcryptSecureHasherTest {
} }
// Assert // Assert
assert results.every { it == EXPECTED_HASH_HEX } results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -90,8 +95,8 @@ class BcryptSecureHasherTest {
} }
// Assert // Assert
assert results.unique().size() == results.size() assertEquals(results.size(), results.unique().size())
assert results.every { it != EXPECTED_HASH_HEX } results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -131,21 +136,20 @@ class BcryptSecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input) String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert // Assert
assert staticSaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assert arbitrarySaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assert differentSaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assert differentSaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64
assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
} }
@Test @Test
@ -182,7 +186,7 @@ class BcryptSecureHasherTest {
logger.info("Generated hash: ${hashHex}") logger.info("Generated hash: ${hashHex}")
// Assert // Assert
assert hashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, hashHex)
} }
@Test @Test
@ -199,7 +203,7 @@ class BcryptSecureHasherTest {
logger.info("Generated hash: ${hashB64}") logger.info("Generated hash: ${hashB64}")
// Assert // Assert
assert hashB64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, hashB64)
} }
@Test @Test
@ -227,8 +231,8 @@ class BcryptSecureHasherTest {
} }
// Assert // Assert
assert hexResults.every { it == EXPECTED_HASH_HEX } hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
assert B64Results.every { it == EXPECTED_HASH_BASE64 } B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
} }
/** /**
@ -263,8 +267,8 @@ class BcryptSecureHasherTest {
// Assert // Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assert resultDurations.min() > MIN_DURATION_NANOS assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
} }
@Test @Test
@ -272,11 +276,8 @@ class BcryptSecureHasherTest {
// Arrange // Arrange
final int cost = 14 final int cost = 14
// Act // Act and Assert
boolean valid = BcryptSecureHasher.isCostValid(cost) assertTrue(BcryptSecureHasher.isCostValid(cost))
// Assert
assert valid
} }
@Test @Test
@ -284,17 +285,8 @@ class BcryptSecureHasherTest {
// Arrange // Arrange
def costFactors = [-8, 0, 40] def costFactors = [-8, 0, 40]
// Act // Act and Assert
def results = costFactors.collect { costFactor -> costFactors.forEach(costFactor -> assertFalse(BcryptSecureHasher.isCostValid(costFactor)))
def isValid = BcryptSecureHasher.isCostValid(costFactor)
[costFactor, isValid]
}
// Assert
results.each { costFactor, isCostValid ->
logger.info("For cost factor ${costFactor}, cost is ${isCostValid ? "valid" : "invalid"}")
assert !isCostValid
}
} }
@Test @Test
@ -302,16 +294,9 @@ class BcryptSecureHasherTest {
// Arrange // Arrange
def saltLengths = [0, 16] def saltLengths = [0, 16]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher()
def isValid = new BcryptSecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> assertTrue(bcryptSecureHasher.isSaltLengthValid(saltLength)))
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
assert { it == isSaltLengthValid }
}
} }
@Test @Test
@ -319,17 +304,9 @@ class BcryptSecureHasherTest {
// Arrange // Arrange
def saltLengths = [-8, 1] def saltLengths = [-8, 1]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> BcryptSecureHasher bcryptSecureHasher = new BcryptSecureHasher()
def isValid = new BcryptSecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> assertFalse(bcryptSecureHasher.isSaltLengthValid(saltLength)))
[saltLength, isValid]
}
// Assert
results.each { saltLength, isSaltLengthValid ->
logger.info("For Salt Length value ${saltLength}, saltLength is ${isSaltLengthValid ? "valid" : "invalid"}")
assert !isSaltLengthValid
}
} }
@Test @Test
@ -350,8 +327,8 @@ class BcryptSecureHasherTest {
logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}") logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}")
// Assert // Assert
assert convertedBase64 == EXPECTED_MIME_B64 assertEquals(EXPECTED_MIME_B64, convertedBase64)
assert convertedRadix64 == INPUT_RADIX_64 assertEquals(INPUT_RADIX_64, convertedRadix64)
} }
@Test @Test
@ -372,8 +349,8 @@ class BcryptSecureHasherTest {
logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}") logger.info("Converted (B64) ${convertedBase64} to (R64) ${convertedRadix64}")
// Assert // Assert
assert convertedBase64 == EXPECTED_MIME_B64 assertEquals(EXPECTED_MIME_B64, convertedBase64)
assert convertedRadix64 == INPUT_RADIX_64 assertEquals(INPUT_RADIX_64, convertedRadix64)
} }
} }

View File

@ -1,64 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.util.crypto
import org.apache.nifi.security.util.KeyDerivationFunction
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.security.Security
class CipherProviderFactoryGroovyTest extends GroovyTestCase {
private static final Logger logger = LoggerFactory.getLogger(CipherProviderFactoryGroovyTest.class)
private static final Map<KeyDerivationFunction, Class> EXPECTED_CIPHER_PROVIDERS = [
(KeyDerivationFunction.BCRYPT) : BcryptCipherProvider.class,
(KeyDerivationFunction.NIFI_LEGACY) : NiFiLegacyCipherProvider.class,
(KeyDerivationFunction.NONE) : AESKeyedCipherProvider.class,
(KeyDerivationFunction.OPENSSL_EVP_BYTES_TO_KEY): OpenSSLPKCS5CipherProvider.class,
(KeyDerivationFunction.PBKDF2) : PBKDF2CipherProvider.class,
(KeyDerivationFunction.SCRYPT) : ScryptCipherProvider.class,
(KeyDerivationFunction.ARGON2) : Argon2CipherProvider.class
]
@BeforeAll
static void setUpOnce() throws Exception {
Security.addProvider(new BouncyCastleProvider())
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}
@Test
void testGetCipherProviderShouldResolveRegisteredKDFs() {
// Arrange
// Act
KeyDerivationFunction.values().each { KeyDerivationFunction kdf ->
logger.info("Expected: ${kdf.kdfName} -> ${EXPECTED_CIPHER_PROVIDERS.get(kdf).simpleName}")
CipherProvider cp = CipherProviderFactory.getCipherProvider(kdf)
logger.info("Resolved: ${kdf.kdfName} -> ${cp.class.simpleName}")
// Assert
assert cp.class == (EXPECTED_CIPHER_PROVIDERS.get(kdf))
}
}
}

View File

@ -27,7 +27,12 @@ import org.slf4j.LoggerFactory
import java.security.Security import java.security.Security
class CipherUtilityGroovyTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertTrue
class CipherUtilityGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(CipherUtilityGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(CipherUtilityGroovyTest.class)
// TripleDES must precede DES for automatic grouping precedence // TripleDES must precede DES for automatic grouping precedence
@ -93,7 +98,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Extracted ${cipher} from ${algorithm}") logger.info("Extracted ${cipher} from ${algorithm}")
// Assert // Assert
assert EXPECTED_ALGORITHMS.get(cipher).contains(algorithm) assertTrue(EXPECTED_ALGORITHMS.get(cipher).contains(algorithm))
} }
} }
@ -108,7 +113,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Extracted ${keyLength} from ${algorithm}") logger.info("Extracted ${keyLength} from ${algorithm}")
// Assert // Assert
assert EXPECTED_ALGORITHMS.get(keyLength).contains(algorithm) assertTrue(EXPECTED_ALGORITHMS.get(keyLength).contains(algorithm))
} }
} }
@ -122,7 +127,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${keyLength} for ${algorithm}") logger.info("Checking ${keyLength} for ${algorithm}")
// Assert // Assert
assert CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm)) assertTrue(CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm)))
} }
} }
} }
@ -143,9 +148,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}") logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")
// Assert // Assert
invalidKeyLengths.each { int invalidKeyLength -> invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))))
assert !CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))
}
} }
} }
} }
@ -160,7 +163,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${keyLength} for ${algorithm}") logger.info("Checking ${keyLength} for ${algorithm}")
// Assert // Assert
assert CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm) assertTrue(CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm))
} }
} }
} }
@ -181,9 +184,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}") logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")
// Assert // Assert
invalidKeyLengths.each { int invalidKeyLength -> invalidKeyLengths.forEach(invalidKeyLength -> assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)))
assert !CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)
}
} }
} }
@ -191,7 +192,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
String algorithm = "PBEWITHSHA256AND256BITAES-CBC-BC" String algorithm = "PBEWITHSHA256AND256BITAES-CBC-BC"
int invalidKeyLength = 192 int invalidKeyLength = 192
logger.info("Checking ${invalidKeyLength} for ${algorithm}") logger.info("Checking ${invalidKeyLength} for ${algorithm}")
assert !CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm) assertFalse(CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm))
} }
@Test @Test
@ -223,7 +224,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}") logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")
// Assert // Assert
assert validKeySizes == EXPECTED_KEY_SIZES assertEquals(EXPECTED_KEY_SIZES, validKeySizes)
} }
// Act // Act
@ -235,7 +236,7 @@ class CipherUtilityGroovyTest extends GroovyTestCase {
logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}") logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")
// Assert // Assert
assert validKeySizes == EXPECTED_KEY_SIZES assertEquals(EXPECTED_KEY_SIZES, validKeySizes)
} }
} }
@ -269,10 +270,10 @@ the License. You may obtain a copy of the License at
logger.info("Looking for ${Hex.encodeHexString(kafka)}; found at ${kafkaIndex}") logger.info("Looking for ${Hex.encodeHexString(kafka)}; found at ${kafkaIndex}")
// Assert // Assert
assert apacheIndex == 16 assertEquals(16, apacheIndex)
assert softwareIndex == 23 assertEquals(23, softwareIndex)
assert asfIndex == 44 assertEquals(44, asfIndex)
assert kafkaIndex == -1 assertEquals(-1, kafkaIndex)
} }
@Test @Test
@ -285,7 +286,7 @@ the License. You may obtain a copy of the License at
String SCRYPT_SALT = ScryptCipherProvider.formatSaltForScrypt(PLAIN_SALT, 10, 1, 1) String SCRYPT_SALT = ScryptCipherProvider.formatSaltForScrypt(PLAIN_SALT, 10, 1, 1)
// Act // Act
def results = KeyDerivationFunction.values().findAll { !it.isStrongKDF() }.collectEntries { KeyDerivationFunction weakKdf -> Map<Object, byte[]> results = KeyDerivationFunction.values().findAll { !it.isStrongKDF() }.collectEntries { KeyDerivationFunction weakKdf ->
[weakKdf, CipherUtility.extractRawSalt(PLAIN_SALT, weakKdf)] [weakKdf, CipherUtility.extractRawSalt(PLAIN_SALT, weakKdf)]
} }
@ -295,6 +296,6 @@ the License. You may obtain a copy of the License at
results.put(KeyDerivationFunction.PBKDF2, CipherUtility.extractRawSalt(PLAIN_SALT, KeyDerivationFunction.PBKDF2)) results.put(KeyDerivationFunction.PBKDF2, CipherUtility.extractRawSalt(PLAIN_SALT, KeyDerivationFunction.PBKDF2))
// Assert // Assert
assert results.every { k, v -> v == PLAIN_SALT } results.values().forEach(v -> assertArrayEquals(PLAIN_SALT, v))
} }
} }

View File

@ -25,7 +25,10 @@ import org.slf4j.LoggerFactory
import java.security.Security import java.security.Security
class HashAlgorithmTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertTrue
class HashAlgorithmTest {
private static final Logger logger = LoggerFactory.getLogger(HashAlgorithmTest.class) private static final Logger logger = LoggerFactory.getLogger(HashAlgorithmTest.class)
@ -48,7 +51,7 @@ class HashAlgorithmTest extends GroovyTestCase {
logger.info("Broken algorithms: ${brokenAlgorithms}") logger.info("Broken algorithms: ${brokenAlgorithms}")
// Assert // Assert
assert brokenAlgorithms == [HashAlgorithm.MD2, HashAlgorithm.MD5, HashAlgorithm.SHA1] assertEquals([HashAlgorithm.MD2, HashAlgorithm.MD5, HashAlgorithm.SHA1], brokenAlgorithms)
} }
@Test @Test
@ -62,11 +65,11 @@ class HashAlgorithmTest extends GroovyTestCase {
} }
// Assert // Assert
assert descriptions.every { descriptions.forEach(description -> assertTrue((description =~ /.* \(\d+ byte output\).*/).find()) )
it =~ /.* \(\d+ byte output\).*/
}
assert descriptions.findAll { it =~ "MD2|MD5|SHA-1" }.every { it =~ /\[WARNING/ } descriptions.stream()
.filter(description -> (description =~ "MD2|MD5|SHA-1").find() )
.forEach(description -> assertTrue(description.contains("WARNING")))
} }
@Test @Test
@ -78,7 +81,7 @@ class HashAlgorithmTest extends GroovyTestCase {
logger.info("Blake2 algorithms: ${blake2Algorithms}") logger.info("Blake2 algorithms: ${blake2Algorithms}")
// Assert // Assert
assert blake2Algorithms == [HashAlgorithm.BLAKE2_160, HashAlgorithm.BLAKE2_256, HashAlgorithm.BLAKE2_384, HashAlgorithm.BLAKE2_512] assertEquals([HashAlgorithm.BLAKE2_160, HashAlgorithm.BLAKE2_256, HashAlgorithm.BLAKE2_384, HashAlgorithm.BLAKE2_512], blake2Algorithms)
} }
@Test @Test
@ -95,8 +98,7 @@ class HashAlgorithmTest extends GroovyTestCase {
HashAlgorithm found = HashAlgorithm.fromName(name) HashAlgorithm found = HashAlgorithm.fromName(name)
// Assert // Assert
assert found instanceof HashAlgorithm assertEquals(name.toUpperCase(), found.name)
assert found.name == name.toUpperCase()
} }
} }
} }

View File

@ -29,7 +29,14 @@ import java.nio.charset.Charset
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.security.Security import java.security.Security
class HashServiceTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertInstanceOf
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class HashServiceTest {
private static final Logger logger = LoggerFactory.getLogger(HashServiceTest.class) private static final Logger logger = LoggerFactory.getLogger(HashServiceTest.class)
@BeforeAll @BeforeAll
@ -70,9 +77,9 @@ class HashServiceTest extends GroovyTestCase {
// Assert // Assert
if (result instanceof byte[]) { if (result instanceof byte[]) {
assert result == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, result)
} else { } else {
assert result == EXPECTED_HASH assertEquals(EXPECTED_HASH, result)
} }
} }
} }
@ -90,11 +97,12 @@ class HashServiceTest extends GroovyTestCase {
logger.info("UTF-16: ${utf16Hash}") logger.info("UTF-16: ${utf16Hash}")
// Assert // Assert
assert utf8Hash != utf16Hash assertNotEquals(utf8Hash, utf16Hash)
} }
/** /**
* This test ensures that the service properly handles UTF-16 encoded data to return it without the Big Endian Byte Order Mark (BOM). Java treats UTF-16 encoded data without a BOM as Big Endian by default on decoding, but when <em>encoding</em>, it inserts a BE BOM in the data. * This test ensures that the service properly handles UTF-16 encoded data to return it without
* the Big Endian Byte Order Mark (BOM). Java treats UTF-16 encoded data without a BOM as Big Endian by default on decoding, but when <em>encoding</em>, it inserts a BE BOM in the data.
* *
* Examples: * Examples:
* *
@ -138,7 +146,7 @@ class HashServiceTest extends GroovyTestCase {
logger.info("${algorithm.name}(${KNOWN_VALUE}, ${charset.name().padLeft(9)}) = ${hash}") logger.info("${algorithm.name}(${KNOWN_VALUE}, ${charset.name().padLeft(9)}) = ${hash}")
// Assert // Assert
assert hash == EXPECTED_SHA_256_HASHES[translateStringToMapKey(charset.name())] assertEquals(EXPECTED_SHA_256_HASHES[translateStringToMapKey(charset.name())], hash)
} }
} }
@ -162,16 +170,15 @@ class HashServiceTest extends GroovyTestCase {
logger.info("Implicit UTF-8 bytes: ${implicitUTF8HashBytesDefault}") logger.info("Implicit UTF-8 bytes: ${implicitUTF8HashBytesDefault}")
// Assert // Assert
assert explicitUTF8Hash == implicitUTF8Hash assertEquals(explicitUTF8Hash, implicitUTF8Hash)
assert explicitUTF8HashBytes == implicitUTF8HashBytes assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytes)
assert explicitUTF8HashBytes == implicitUTF8HashBytesDefault assertArrayEquals(explicitUTF8HashBytes, implicitUTF8HashBytesDefault)
} }
@Test @Test
void testShouldRejectNullAlgorithm() { void testShouldRejectNullAlgorithm() {
// Arrange // Arrange
final String KNOWN_VALUE = "apachenifi" final String KNOWN_VALUE = "apachenifi"
Closure threeArgString = { -> HashService.hashValue(null, KNOWN_VALUE, StandardCharsets.UTF_8) } Closure threeArgString = { -> HashService.hashValue(null, KNOWN_VALUE, StandardCharsets.UTF_8) }
Closure twoArgString = { -> HashService.hashValue(null, KNOWN_VALUE) } Closure twoArgString = { -> HashService.hashValue(null, KNOWN_VALUE) }
Closure threeArgStringRaw = { -> HashService.hashValueRaw(null, KNOWN_VALUE, StandardCharsets.UTF_8) } Closure threeArgStringRaw = { -> HashService.hashValueRaw(null, KNOWN_VALUE, StandardCharsets.UTF_8) }
@ -186,15 +193,10 @@ class HashServiceTest extends GroovyTestCase {
] ]
// Act // Act
scenarios.each { String name, Closure closure -> scenarios.entrySet().forEach(entry -> {
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> entry.getValue().call())
closure.call() assertTrue(iae.message.contains("The hash algorithm cannot be null"))
} })
logger.expected("${name.padLeft(20)}: ${msg}")
// Assert
assert msg =~ "The hash algorithm cannot be null"
}
} }
@Test @Test
@ -216,15 +218,10 @@ class HashServiceTest extends GroovyTestCase {
] ]
// Act // Act
scenarios.each { String name, Closure closure -> scenarios.entrySet().forEach(entry -> {
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> entry.getValue().call())
closure.call() assertTrue(iae.message.contains("The value cannot be null"))
} })
logger.expected("${name.padLeft(20)}: ${msg}")
// Assert
assert msg =~ "The value cannot be null"
}
} }
@Test @Test
@ -266,7 +263,7 @@ class HashServiceTest extends GroovyTestCase {
// Assert // Assert
generatedHashes.each { String algorithmName, String hash -> generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName) String key = translateStringToMapKey(algorithmName)
assert EXPECTED_HASHES[key] == hash assertEquals(EXPECTED_HASHES[key], hash)
} }
} }
@ -309,7 +306,7 @@ class HashServiceTest extends GroovyTestCase {
// Assert // Assert
generatedHashes.each { String algorithmName, String hash -> generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName) String key = translateStringToMapKey(algorithmName)
assert EXPECTED_HASHES[key] == hash assertEquals(EXPECTED_HASHES[key], hash)
} }
} }
@ -323,14 +320,14 @@ class HashServiceTest extends GroovyTestCase {
def allowableValues = HashService.buildHashAlgorithmAllowableValues() def allowableValues = HashService.buildHashAlgorithmAllowableValues()
// Assert // Assert
assert allowableValues instanceof AllowableValue[] assertInstanceOf(AllowableValue[].class, allowableValues)
def valuesList = allowableValues as List<AllowableValue> def valuesList = allowableValues as List<AllowableValue>
assert valuesList.size() == EXPECTED_ALGORITHMS.size() assertEquals(EXPECTED_ALGORITHMS.size(), valuesList.size())
EXPECTED_ALGORITHMS.each { HashAlgorithm expectedAlgorithm -> EXPECTED_ALGORITHMS.each { HashAlgorithm expectedAlgorithm ->
def matchingValue = valuesList.find { it.value == expectedAlgorithm.name } def matchingValue = valuesList.find { it.value == expectedAlgorithm.name }
assert matchingValue.displayName == expectedAlgorithm.name assertEquals(expectedAlgorithm.name, matchingValue.displayName)
assert matchingValue.description == expectedAlgorithm.buildAllowableValueDescription() assertEquals(expectedAlgorithm.buildAllowableValueDescription(), matchingValue.description)
} }
} }
@ -347,20 +344,21 @@ class HashServiceTest extends GroovyTestCase {
] ]
logger.info("The consistent list of character sets available [${EXPECTED_CHARACTER_SETS.size()}]: \n${EXPECTED_CHARACTER_SETS.collect { "\t${it.name()}" }.join("\n")}") logger.info("The consistent list of character sets available [${EXPECTED_CHARACTER_SETS.size()}]: \n${EXPECTED_CHARACTER_SETS.collect { "\t${it.name()}" }.join("\n")}")
def expectedDescriptions = ["UTF-16": "This character set normally decodes using an optional BOM at the beginning of the data but encodes by inserting a BE BOM. For hashing, it will be replaced with UTF-16BE. "] def expectedDescriptions =
["UTF-16": "This character set normally decodes using an optional BOM at the beginning of the data but encodes by inserting a BE BOM. For hashing, it will be replaced with UTF-16BE. "]
// Act // Act
def allowableValues = HashService.buildCharacterSetAllowableValues() def allowableValues = HashService.buildCharacterSetAllowableValues()
// Assert // Assert
assert allowableValues instanceof AllowableValue[] assertInstanceOf(AllowableValue[].class, allowableValues)
def valuesList = allowableValues as List<AllowableValue> def valuesList = allowableValues as List<AllowableValue>
assert valuesList.size() == EXPECTED_CHARACTER_SETS.size() assertEquals(EXPECTED_CHARACTER_SETS.size(), valuesList.size())
EXPECTED_CHARACTER_SETS.each { Charset charset -> EXPECTED_CHARACTER_SETS.each { Charset charset ->
def matchingValue = valuesList.find { it.value == charset.name() } def matchingValue = valuesList.find { it.value == charset.name() }
assert matchingValue.displayName == charset.name() assertEquals(charset.name(), matchingValue.displayName)
assert matchingValue.description == (expectedDescriptions[charset.name()] ?: charset.displayName()) assertEquals((expectedDescriptions[charset.name()] ?: charset.displayName()), matchingValue.description)
} }
} }
@ -410,7 +408,7 @@ class HashServiceTest extends GroovyTestCase {
// Assert // Assert
generatedHashes.each { String algorithmName, String hash -> generatedHashes.each { String algorithmName, String hash ->
String key = translateStringToMapKey(algorithmName) String key = translateStringToMapKey(algorithmName)
assert EXPECTED_HASHES[key] == hash assertEquals(EXPECTED_HASHES[key], hash)
} }
} }

View File

@ -21,7 +21,6 @@ import org.apache.nifi.security.util.EncryptionMethod
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
@ -32,6 +31,8 @@ import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.PBEParameterSpec import javax.crypto.spec.PBEParameterSpec
import java.security.Security import java.security.Security
import static org.junit.jupiter.api.Assertions.assertEquals
class NiFiLegacyCipherProviderGroovyTest { class NiFiLegacyCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(NiFiLegacyCipherProviderGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(NiFiLegacyCipherProviderGroovyTest.class)
@ -100,7 +101,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -130,7 +131,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -165,7 +166,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -199,7 +200,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -230,7 +231,7 @@ class NiFiLegacyCipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
} }

View File

@ -31,8 +31,11 @@ import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.PBEParameterSpec import javax.crypto.spec.PBEParameterSpec
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.Assert.fail import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assertions.fail
class OpenSSLPKCS5CipherProviderGroovyTest { class OpenSSLPKCS5CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(OpenSSLPKCS5CipherProviderGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(OpenSSLPKCS5CipherProviderGroovyTest.class)
@ -99,7 +102,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -128,7 +131,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -162,7 +165,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -196,7 +199,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -227,7 +230,7 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
String recovered = new String(recoveredBytes, "UTF-8") String recovered = new String(recoveredBytes, "UTF-8")
// Assert // Assert
assert plaintext.equals(recovered) assertEquals(plaintext, recovered)
} }
} }
@ -242,12 +245,11 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
// Act // Act
logger.info("Using algorithm: null") logger.info("Using algorithm: null")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher providedCipher = cipherProvider.getCipher(null, PASSWORD, SALT, false) () -> cipherProvider.getCipher(null, PASSWORD, SALT, false))
}
// Assert // Assert
assert msg =~ "The encryption method must be specified" assertTrue(iae.getMessage().contains("The encryption method must be specified"))
} }
@Test @Test
@ -261,12 +263,11 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
// Act // Act
logger.info("Using algorithm: ${encryptionMethod}") logger.info("Using algorithm: ${encryptionMethod}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher providedCipher = cipherProvider.getCipher(encryptionMethod, "", SALT, false) () -> cipherProvider.getCipher(encryptionMethod, "", SALT, false))
}
// Assert // Assert
assert msg =~ "Encryption with an empty password is not supported" assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
} }
@Test @Test
@ -280,13 +281,11 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
// Act // Act
logger.info("Using algorithm: ${encryptionMethod}") logger.info("Using algorithm: ${encryptionMethod}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
def msg = shouldFail(IllegalArgumentException) { () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, false))
Cipher providedCipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, false)
}
// Assert // Assert
assert msg =~ "Salt must be 8 bytes US-ASCII encoded" assertTrue(iae.getMessage().contains("Salt must be 8 bytes US-ASCII encoded"))
} }
@Test @Test
@ -299,7 +298,9 @@ class OpenSSLPKCS5CipherProviderGroovyTest {
logger.info("Checking salt ${Hex.encodeHexString(salt)}") logger.info("Checking salt ${Hex.encodeHexString(salt)}")
// Assert // Assert
assert salt.length == cipherProvider.getDefaultSaltLength() assertEquals(cipherProvider.getDefaultSaltLength(), salt.length)
assert salt != [(0x00 as byte) * cipherProvider.defaultSaltLength] byte [] notExpected = new byte [cipherProvider.defaultSaltLength]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, salt))
} }
} }

View File

@ -28,8 +28,11 @@ import org.slf4j.LoggerFactory
import javax.crypto.Cipher import javax.crypto.Cipher
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.Assert.assertTrue import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class PBKDF2CipherProviderGroovyTest { class PBKDF2CipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(PBKDF2CipherProviderGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(PBKDF2CipherProviderGroovyTest.class)
@ -85,7 +88,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -108,12 +111,11 @@ class PBKDF2CipherProviderGroovyTest {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, true)
// Decrypt should fail // Decrypt should fail
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, badIV, DEFAULT_KEY_LENGTH, false))
}
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }
@ -143,7 +145,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -174,7 +176,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -192,12 +194,11 @@ class PBKDF2CipherProviderGroovyTest {
// Act // Act
logger.info("Using PRF ${prf}") logger.info("Using PRF ${prf}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider = new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT) () -> new PBKDF2CipherProvider(prf, TEST_ITERATION_COUNT))
}
// Assert // Assert
assert msg =~ "Cannot resolve empty PRF" assertTrue(iae.getMessage().contains("Cannot resolve empty PRF"))
} }
@Test @Test
@ -234,7 +235,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
@Test @Test
@ -270,7 +271,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -300,7 +301,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
@Test @Test
@ -324,15 +325,15 @@ class PBKDF2CipherProviderGroovyTest {
byte[] sha512CipherBytes = sha512Cipher.doFinal(PLAINTEXT.bytes) byte[] sha512CipherBytes = sha512Cipher.doFinal(PLAINTEXT.bytes)
// Assert // Assert
assert sha512CipherBytes != sha256CipherBytes assertFalse(Arrays.equals(sha512CipherBytes, sha256CipherBytes))
Cipher sha256DecryptCipher = sha256CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false) Cipher sha256DecryptCipher = sha256CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] sha256RecoveredBytes = sha256DecryptCipher.doFinal(sha256CipherBytes) byte[] sha256RecoveredBytes = sha256DecryptCipher.doFinal(sha256CipherBytes)
assert sha256RecoveredBytes == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, sha256RecoveredBytes)
Cipher sha512DecryptCipher = sha512CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false) Cipher sha512DecryptCipher = sha512CP.getCipher(encryptionMethod, PASSWORD, SALT, IV, DEFAULT_KEY_LENGTH, false)
byte[] sha512RecoveredBytes = sha512DecryptCipher.doFinal(sha512CipherBytes) byte[] sha512RecoveredBytes = sha512DecryptCipher.doFinal(sha512CipherBytes)
assert sha512RecoveredBytes == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, sha512RecoveredBytes)
} }
@Test @Test
@ -355,12 +356,11 @@ class PBKDF2CipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}") logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
}
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains( "Cannot decrypt without a valid IV"))
} }
} }
@ -380,12 +380,11 @@ class PBKDF2CipherProviderGroovyTest {
INVALID_SALTS.each { String salt -> INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}") logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt?.bytes, DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt?.bytes, DEFAULT_KEY_LENGTH, true))
}
// Assert // Assert
assert msg =~ "The salt must be at least 16 bytes\\. To generate a salt, use PBKDF2CipherProvider#generateSalt" assertTrue(iae.getMessage().contains("The salt must be at least 16 bytes. To generate a salt, use PBKDF2CipherProvider#generateSalt"))
} }
} }
@ -419,7 +418,7 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -441,12 +440,11 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
}
// Assert // Assert
assert msg =~ "${keyLength} is not a valid key length for AES" assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
} }
} }
@ -537,7 +535,9 @@ class PBKDF2CipherProviderGroovyTest {
logger.info("Checking salt ${Hex.encodeHexString(salt)}") logger.info("Checking salt ${Hex.encodeHexString(salt)}")
// Assert // Assert
assert salt.length == 16 assertEquals(16,salt.length )
assert salt != [(0x00 as byte) * 16] byte [] notExpected = new byte[16]
Arrays.fill(notExpected, 0x00 as byte)
assertFalse(Arrays.equals(notExpected, salt))
} }
} }

View File

@ -21,8 +21,15 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.util.stream.Collectors
import java.util.stream.Stream
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class PBKDF2SecureHasherTest { class PBKDF2SecureHasherTest {
@ -31,25 +38,21 @@ class PBKDF2SecureHasherTest {
// Arrange // Arrange
int cost = 10_000 int cost = 10_000
int dkLength = 32 int dkLength = 32
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511" final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511"
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(cost, dkLength)
def results = []
// Act // Act
testIterations.times { int i -> PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(cost, dkLength)
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes) List<String> results = Stream.iterate(0, n -> n + 1)
String hashHex = new String(Hex.encode(hash)) .limit(10)
results << hashHex .map(iteration -> {
} byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes)
return new String(Hex.encode(hash))
})
.collect(Collectors.toList())
// Assert // Assert
assert results.every { it == EXPECTED_HASH_HEX } results.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -59,26 +62,22 @@ class PBKDF2SecureHasherTest {
int cost = 10_000 int cost = 10_000
int saltLength = 16 int saltLength = 16
int dkLength = 32 int dkLength = 32
int testIterations = 10
byte[] inputBytes = "This is a sensitive value".bytes byte[] inputBytes = "This is a sensitive value".bytes
final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511" final String EXPECTED_HASH_HEX = "2c47a6d801b71e087f94792079c40880aea29013bfffd0ab94b1bc112ea52511"
//Act
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength) PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher(prf, cost, saltLength, dkLength)
List<String> results = Stream.iterate(0, n -> n + 1)
def results = [] .limit(10)
.map(iteration -> {
// Act byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes)
testIterations.times { int i -> return new String(Hex.encode(hash))
byte[] hash = pbkdf2SecureHasher.hashRaw(inputBytes) })
String hashHex = Hex.encode(hash) .collect(Collectors.toList())
results << hashHex
}
// Assert // Assert
assert results.unique().size() == results.size() assertEquals(results.unique().size(), results.size())
assert results.every { it != EXPECTED_HASH_HEX } results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -119,20 +118,20 @@ class PBKDF2SecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input) String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert // Assert
assert staticSaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assert arbitrarySaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assert differentSaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assert differentSaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64 assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64 assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
} }
@Test @Test
@ -169,7 +168,7 @@ class PBKDF2SecureHasherTest {
String hashHex = pbkdf2SecureHasher.hashHex(input) String hashHex = pbkdf2SecureHasher.hashHex(input)
// Assert // Assert
assert hashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, hashHex)
} }
@Test @Test
@ -185,7 +184,7 @@ class PBKDF2SecureHasherTest {
String hashB64 = pbkdf2SecureHasher.hashBase64(input) String hashB64 = pbkdf2SecureHasher.hashBase64(input)
// Assert // Assert
assert hashB64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, hashB64)
} }
@Test @Test
@ -196,23 +195,18 @@ class PBKDF2SecureHasherTest {
final String EXPECTED_HASH_HEX = "7f2d8d8c7aaa45471f6c05a8edfe0a3f75fe01478cc965c5dce664e2ac6f5d0a" final String EXPECTED_HASH_HEX = "7f2d8d8c7aaa45471f6c05a8edfe0a3f75fe01478cc965c5dce664e2ac6f5d0a"
final String EXPECTED_HASH_BASE64 = "fy2NjHqqRUcfbAWo7f4KP3X+AUeMyWXF3OZk4qxvXQo" final String EXPECTED_HASH_BASE64 = "fy2NjHqqRUcfbAWo7f4KP3X+AUeMyWXF3OZk4qxvXQo"
PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
def hexResults = []
def B64Results = []
// Act // Act
inputs.each { String input -> PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
String hashHex = pbkdf2SecureHasher.hashHex(input) List<String> hexResults = inputs.stream()
hexResults << hashHex .map(input -> pbkdf2SecureHasher.hashHex(input))
.collect(Collectors.toList())
String hashB64 = pbkdf2SecureHasher.hashBase64(input) List<String> B64Results = inputs.stream()
B64Results << hashB64 .map(input -> pbkdf2SecureHasher.hashBase64(input))
} .collect(Collectors.toList())
// Assert // Assert
assert hexResults.every { it == EXPECTED_HASH_HEX } hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
assert B64Results.every { it == EXPECTED_HASH_BASE64 } B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
} }
/** /**
@ -246,8 +240,8 @@ class PBKDF2SecureHasherTest {
// Assert // Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assert resultDurations.min() > MIN_DURATION_NANOS assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
} }
@Test @Test
@ -262,99 +256,67 @@ class PBKDF2SecureHasherTest {
} }
// Assert // Assert
assert results.every() assertTrue(results.every())
} }
@Test @Test
void testShouldFailIterationCountBoundary() throws Exception { void testShouldFailIterationCountBoundary() throws Exception {
// Arrange // Arrange
def invalidIterationCounts = [-1, 0, Integer.MAX_VALUE + 1] List<Integer> invalidIterationCounts = [-1, 0, Integer.MAX_VALUE + 1]
// Act // Act and Assert
def results = invalidIterationCounts.collect { i -> invalidIterationCounts.forEach(i -> assertFalse(PBKDF2SecureHasher.isIterationCountValid(i)))
boolean valid = PBKDF2SecureHasher.isIterationCountValid(i)
valid
}
// Assert
results.each { valid ->
assert !valid
}
} }
@Test @Test
void testShouldVerifyDKLengthBoundary() throws Exception { void testShouldVerifyDKLengthBoundary() throws Exception {
// Arrange // Arrange
def validHLengths = [32, 64] List<Integer> validHLengths = [32, 64]
// 1 and MAX_VALUE are the length boundaries, inclusive // 1 and MAX_VALUE are the length boundaries, inclusive
def validDKLengths = [1, 1000, 1_000_000, Integer.MAX_VALUE] List<Integer> validDKLengths = [1, 1000, 1_000_000, Integer.MAX_VALUE]
// Act // Act and Assert
def results = validHLengths.collectEntries { int hLen -> validHLengths.forEach(hLen -> {
def dkResults = validDKLengths.collect { int dkLength -> validDKLengths.forEach(dkLength -> {
boolean valid = PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength) assertTrue(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength))
valid })
} })
[hLen, dkResults]
}
// Assert
results.each { int hLen, def dkResults ->
assert dkResults.every()
}
} }
@Test @Test
void testShouldFailDKLengthBoundary() throws Exception { void testShouldFailDKLengthBoundary() throws Exception {
// Arrange // Arrange
def validHLengths = [32, 64] List<Integer> validHLengths = [32, 64]
// MAX_VALUE + 1 will become MIN_VALUE because of signed integer math // MAX_VALUE + 1 will become MIN_VALUE because of signed integer math
def invalidDKLengths = [-1, 0, Integer.MAX_VALUE + 1, new Integer(Integer.MAX_VALUE * 2 - 1)] List<Integer> invalidDKLengths = [-1, 0, Integer.MAX_VALUE + 1, new Integer(Integer.MAX_VALUE * 2 - 1)]
// Act // Act and Assert
def results = validHLengths.collectEntries { int hLen -> validHLengths.forEach(hLen -> {
def dkResults = invalidDKLengths.collect { int dkLength -> invalidDKLengths.forEach(dkLength -> {
boolean valid = PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength) assertFalse(PBKDF2SecureHasher.isDKLengthValid(hLen, dkLength))
valid })
} })
[hLen, dkResults]
}
// Assert
results.each { int hLen, def dkResults ->
assert dkResults.every { boolean valid -> !valid }
}
} }
@Test @Test
void testShouldVerifySaltLengthBoundary() throws Exception { void testShouldVerifySaltLengthBoundary() throws Exception {
// Arrange // Arrange
def saltLengths = [0, 16, 64] List<Integer> saltLengths = [0, 16, 64]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
def isValid = new PBKDF2SecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> assertTrue(pbkdf2SecureHasher.isSaltLengthValid(saltLength)))
isValid
}
// Assert
assert results.every()
} }
@Test @Test
void testShouldFailSaltLengthBoundary() throws Exception { void testShouldFailSaltLengthBoundary() throws Exception {
// Arrange // Arrange
def saltLengths = [-8, 1, Integer.MAX_VALUE + 1] List<Integer> saltLengths = [-8, 1, Integer.MAX_VALUE + 1]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> PBKDF2SecureHasher pbkdf2SecureHasher = new PBKDF2SecureHasher()
def isValid = new PBKDF2SecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> assertFalse(pbkdf2SecureHasher.isSaltLengthValid(saltLength)))
isValid
}
// Assert
results.each { assert !it }
} }
} }

View File

@ -35,8 +35,12 @@ import javax.crypto.spec.SecretKeySpec
import java.security.SecureRandom import java.security.SecureRandom
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.Assert.assertTrue import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class ScryptCipherProviderGroovyTest { class ScryptCipherProviderGroovyTest {
private static final Logger logger = LoggerFactory.getLogger(ScryptCipherProviderGroovyTest.class) private static final Logger logger = LoggerFactory.getLogger(ScryptCipherProviderGroovyTest.class)
@ -94,7 +98,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -122,7 +126,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -152,7 +156,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -189,9 +193,9 @@ class ScryptCipherProviderGroovyTest {
byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes) byte[] rubyCipherBytes = rubyCipher.doFinal(PLAINTEXT.bytes)
logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}") logger.sanity("Created cipher text: ${Hex.encodeHexString(rubyCipherBytes)}")
rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec) rubyCipher.init(Cipher.DECRYPT_MODE, rubyKey, ivSpec)
assert rubyCipher.doFinal(rubyCipherBytes) == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(rubyCipherBytes))
logger.sanity("Decrypted generated cipher text successfully") logger.sanity("Decrypted generated cipher text successfully")
assert rubyCipher.doFinal(cipherBytes) == PLAINTEXT.bytes assertArrayEquals(PLAINTEXT.bytes, rubyCipher.doFinal(cipherBytes))
logger.sanity("Decrypted external cipher text successfully") logger.sanity("Decrypted external cipher text successfully")
// n$r$p$hex_salt_SL$hex_hash_HL // n$r$p$hex_salt_SL$hex_hash_HL
@ -214,7 +218,7 @@ class ScryptCipherProviderGroovyTest {
// Convert hash from hex to Base64 // Convert hash from hex to Base64
String base64Hash = CipherUtility.encodeBase64NoPadding(Hex.decodeHex(hashHex as char[])) String base64Hash = CipherUtility.encodeBase64NoPadding(Hex.decodeHex(hashHex as char[]))
logger.info("Converted hash from hex ${hashHex} to Base64 ${base64Hash}") logger.info("Converted hash from hex ${hashHex} to Base64 ${base64Hash}")
assert Hex.encodeHexString(Base64.decodeBase64(base64Hash)) == hashHex assertEquals(hashHex, Hex.encodeHexString(Base64.decodeBase64(base64Hash)))
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}") logger.info("External cipher text: ${CIPHER_TEXT} ${cipherBytes.length}")
@ -226,7 +230,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
@Test @Test
@ -272,7 +276,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
@Test @Test
@ -290,13 +294,12 @@ class ScryptCipherProviderGroovyTest {
INVALID_SALTS.each { String salt -> INVALID_SALTS.each { String salt ->
logger.info("Checking salt ${salt}") logger.info("Checking salt ${salt}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ LENGTH_MESSAGE assertTrue(iae.getMessage().contains(LENGTH_MESSAGE))
} }
} }
@ -317,7 +320,7 @@ class ScryptCipherProviderGroovyTest {
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true) Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, salt.bytes, DEFAULT_KEY_LENGTH, true)
// Assert // Assert
assert cipher assertNotNull(cipher)
} }
} }
@ -330,13 +333,12 @@ class ScryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()}")
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, new byte[0], DEFAULT_KEY_LENGTH, true))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "The salt cannot be empty\\. To generate a salt, use ScryptCipherProvider#generateSalt" assertTrue(iae.getMessage().contains("The salt cannot be empty. To generate a salt, use ScryptCipherProvider#generateSalt"))
} }
@Test @Test
@ -357,13 +359,12 @@ class ScryptCipherProviderGroovyTest {
byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8")) byte[] cipherBytes = cipher.doFinal(PLAINTEXT.getBytes("UTF-8"))
logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}") logger.info("Cipher text: ${Hex.encodeHexString(cipherBytes)} ${cipherBytes.length}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipher = cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false) () -> cipherProvider.getCipher(em, PASSWORD, SALT, DEFAULT_KEY_LENGTH, false))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "Cannot decrypt without a valid IV" assertTrue(iae.getMessage().contains("Cannot decrypt without a valid IV"))
} }
} }
@ -394,7 +395,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Recovered: ${recovered}") logger.info("Recovered: ${recovered}")
// Assert // Assert
assert PLAINTEXT.equals(recovered) assertEquals(PLAINTEXT, recovered)
} }
} }
@ -415,13 +416,12 @@ class ScryptCipherProviderGroovyTest {
logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}") logger.info("Using algorithm: ${encryptionMethod.getAlgorithm()} with key length ${keyLength}")
// Initialize a cipher for encryption // Initialize a cipher for encryption
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Cipher cipher = cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true) () -> cipherProvider.getCipher(encryptionMethod, PASSWORD, SALT, IV, keyLength, true))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "${keyLength} is not a valid key length for AES" assertTrue(iae.getMessage().contains(keyLength + " is not a valid key length for AES"))
} }
} }
@ -434,12 +434,11 @@ class ScryptCipherProviderGroovyTest {
EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true) ()-> cipherProvider.getCipher(encryptionMethod, badPassword, salt, DEFAULT_KEY_LENGTH, true))
}
// Assert // Assert
assert msg =~ "Encryption with an empty password is not supported" assertTrue(iae.getMessage().contains("Encryption with an empty password is not supported"))
} }
@Test @Test
@ -455,9 +454,9 @@ class ScryptCipherProviderGroovyTest {
logger.info("Salt: ${salt}") logger.info("Salt: ${salt}")
// Assert // Assert
assert salt =~ "^(?i)\\\$s0\\\$[a-f0-9]{5,16}\\\$" assertTrue((salt =~ "^(?i)\\\$s0\\\$[a-f0-9]{5,16}\\\$").find())
String params = Scrypt.encodeParams(n, r, p) String params = Scrypt.encodeParams(n, r, p)
assert salt.contains("\$${params}\$") assertTrue(salt.contains("\$${params}\$"))
} }
@Test @Test
@ -480,10 +479,10 @@ class ScryptCipherProviderGroovyTest {
cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params) cipherProvider.parseSalt(FORMATTED_SALT, rawSalt, params)
// Assert // Assert
assert rawSalt == EXPECTED_RAW_SALT assertArrayEquals(EXPECTED_RAW_SALT, rawSalt)
assert params[0] == EXPECTED_N assertEquals(EXPECTED_N, params[0])
assert params[1] == EXPECTED_R assertEquals(EXPECTED_R, params[1])
assert params[2] == EXPECTED_P assertEquals(EXPECTED_P, params[2])
} }
@Test @Test
@ -496,7 +495,7 @@ class ScryptCipherProviderGroovyTest {
boolean valid = ScryptCipherProvider.isPValid(r, p) boolean valid = ScryptCipherProvider.isPValid(r, p)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -504,19 +503,12 @@ class ScryptCipherProviderGroovyTest {
// Arrange // Arrange
// The p upper bound is calculated with the formula below, when r = 8: // The p upper bound is calculated with the formula below, when r = 8:
// pBoundary = ((Math.pow(2,32))-1) * (32.0/(r * 128)), where pBoundary = 134217727.96875; // pBoundary = ((Math.pow(2,32))-1) * (32.0/(r * 128)), where pBoundary = 134217727.96875;
Map costParameters = [8:134217729, 128:8388608, 4096: 0] Map<Integer, Integer> costParameters = [8:134217729, 128:8388608, 4096: 0]
// Act // Act and Assert
def results = costParameters.collectEntries { r, p -> costParameters.entrySet().forEach(entry -> {
def isValid = ScryptCipherProvider.isPValid(r, p) assertFalse(ScryptCipherProvider.isPValid(entry.getKey(), entry.getValue()))
[r, isValid] })
}
// Assert
results.each { r, isPValid ->
logger.info("For r ${r}, p is ${isPValid}")
assert !isPValid
}
} }
@Test @Test
@ -528,7 +520,7 @@ class ScryptCipherProviderGroovyTest {
boolean valid = ScryptCipherProvider.isRValid(r) boolean valid = ScryptCipherProvider.isRValid(r)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -540,7 +532,7 @@ class ScryptCipherProviderGroovyTest {
boolean valid = ScryptCipherProvider.isRValid(r) boolean valid = ScryptCipherProvider.isRValid(r)
// Assert // Assert
assert !valid assertFalse(valid)
} }
@Test @Test
@ -554,7 +546,7 @@ class ScryptCipherProviderGroovyTest {
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p) ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p)
// Assert // Assert
assert testCipherProvider assertNotNull(testCipherProvider)
} }
@Test @Test
@ -565,13 +557,12 @@ class ScryptCipherProviderGroovyTest {
final int p = 0 final int p = 0
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p) () -> new ScryptCipherProvider(n, r, p))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "Invalid p value exceeds p boundary" assertTrue(iae.getMessage().contains("Invalid p value exceeds p boundary"))
} }
@Test @Test
@ -582,13 +573,12 @@ class ScryptCipherProviderGroovyTest {
final int p = 0 final int p = 0
// Act // Act
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
ScryptCipherProvider testCipherProvider = new ScryptCipherProvider(n, r, p) () -> new ScryptCipherProvider(n, r, p))
} logger.warn(iae.getMessage())
logger.expected(msg)
// Assert // Assert
assert msg =~ "Invalid r value; must be greater than 0" assertTrue(iae.getMessage().contains("Invalid r value; must be greater than 0"))
} }
@Test @Test
@ -601,7 +591,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Is Scrypt salt: ${isScryptSalt}") logger.info("Is Scrypt salt: ${isScryptSalt}")
// Assert // Assert
assert isScryptSalt assertTrue(isScryptSalt)
} }
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true", @EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true",
@ -624,7 +614,7 @@ class ScryptCipherProviderGroovyTest {
logger.info("Determined minimum safe parameters to be N=${minimumN}, r=${minimumR}, p=${minimumP}") logger.info("Determined minimum safe parameters to be N=${minimumN}, r=${minimumR}, p=${minimumP}")
// Assert // Assert
assertTrue("The default parameters for ScryptCipherProvider are too weak. Please update the default values to a stronger level.", n >= minimumN) assertTrue(n >= minimumN, "The default parameters for ScryptCipherProvider are too weak. Please update the default values to a stronger level.")
} }
/** /**
@ -644,7 +634,7 @@ class ScryptCipherProviderGroovyTest {
int n = 2**4 int n = 2**4
int dkLen = 128 int dkLen = 128
assert Scrypt.calculateExpectedMemory(n, r, p) <= maxHeapSize assertTrue(Scrypt.calculateExpectedMemory(n, r, p) <= maxHeapSize)
byte[] salt = new byte[Scrypt.defaultSaltLength] byte[] salt = new byte[Scrypt.defaultSaltLength]
new SecureRandom().nextBytes(salt) new SecureRandom().nextBytes(salt)

View File

@ -22,7 +22,12 @@ import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertNotEquals
import static org.junit.jupiter.api.Assertions.assertThrows import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class ScryptSecureHasherTest { class ScryptSecureHasherTest {
@ -51,7 +56,7 @@ class ScryptSecureHasherTest {
} }
// Assert // Assert
assert results.every { it == EXPECTED_HASH_HEX } results.forEach( result -> assertEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -79,8 +84,8 @@ class ScryptSecureHasherTest {
} }
// Assert // Assert
assert results.unique().size() == results.size() assertTrue(results.unique().size() == results.size())
assert results.every { it != EXPECTED_HASH_HEX } results.forEach(result -> assertNotEquals(EXPECTED_HASH_HEX, result))
} }
@Test @Test
@ -122,20 +127,20 @@ class ScryptSecureHasherTest {
String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input) String differentSaltHashBase64 = arbitrarySaltHasher.hashBase64(input)
// Assert // Assert
assert staticSaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, staticSaltHash)
assert arbitrarySaltHash == EXPECTED_HASH_BYTES assertArrayEquals(EXPECTED_HASH_BYTES, arbitrarySaltHash)
assert differentArbitrarySaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentArbitrarySaltHash))
assert differentSaltHash != EXPECTED_HASH_BYTES assertFalse(Arrays.equals(EXPECTED_HASH_BYTES, differentSaltHash))
assert staticSaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, staticSaltHashHex)
assert arbitrarySaltHashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, arbitrarySaltHashHex)
assert differentArbitrarySaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentArbitrarySaltHashHex)
assert differentSaltHashHex != EXPECTED_HASH_HEX assertNotEquals(EXPECTED_HASH_HEX, differentSaltHashHex)
assert staticSaltHashBase64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, staticSaltHashBase64)
assert arbitrarySaltHashBase64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, arbitrarySaltHashBase64)
assert differentArbitrarySaltHashBase64 != EXPECTED_HASH_BASE64 assertNotEquals(EXPECTED_HASH_BASE64, differentArbitrarySaltHashBase64)
assert differentSaltHashBase64 != EXPECTED_HASH_BASE64 assertNotEquals(EXPECTED_HASH_BASE64, differentSaltHashBase64)
} }
@Test @Test
@ -172,7 +177,7 @@ class ScryptSecureHasherTest {
String hashHex = scryptSH.hashHex(input) String hashHex = scryptSH.hashHex(input)
// Assert // Assert
assert hashHex == EXPECTED_HASH_HEX assertEquals(EXPECTED_HASH_HEX, hashHex)
} }
@Test @Test
@ -188,7 +193,7 @@ class ScryptSecureHasherTest {
String hashB64 = scryptSH.hashBase64(input) String hashB64 = scryptSH.hashBase64(input)
// Assert // Assert
assert hashB64 == EXPECTED_HASH_BASE64 assertEquals(EXPECTED_HASH_BASE64, hashB64)
} }
@Test @Test
@ -214,8 +219,8 @@ class ScryptSecureHasherTest {
} }
// Assert // Assert
assert hexResults.every { it == EXPECTED_HASH_HEX } hexResults.forEach(result -> assertEquals(EXPECTED_HASH_HEX, result))
assert B64Results.every { it == EXPECTED_HASH_BASE64 } B64Results.forEach(result -> assertEquals(EXPECTED_HASH_BASE64, result))
} }
/** /**
@ -249,8 +254,8 @@ class ScryptSecureHasherTest {
// Assert // Assert
final long MIN_DURATION_NANOS = 75_000_000 // 75 ms final long MIN_DURATION_NANOS = 75_000_000 // 75 ms
assert resultDurations.min() > MIN_DURATION_NANOS assertTrue(resultDurations.min() > MIN_DURATION_NANOS)
assert resultDurations.sum() / testIterations > MIN_DURATION_NANOS assertTrue(resultDurations.sum() / testIterations > MIN_DURATION_NANOS)
} }
@Test @Test
@ -262,24 +267,16 @@ class ScryptSecureHasherTest {
boolean valid = ScryptSecureHasher.isRValid(r) boolean valid = ScryptSecureHasher.isRValid(r)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
void testShouldFailRBoundary() throws Exception { void testShouldFailRBoundary() throws Exception {
// Arrange // Arrange
def rValues = [-8, 0, 2147483647] List<Integer> rValues = [-8, 0, 2147483647]
// Act // Act and Assert
def results = rValues.collect { rValue -> rValues.forEach(rValue -> assertFalse(ScryptSecureHasher.isRValid(rValue)))
def isValid = ScryptSecureHasher.isRValid(rValue)
[rValue, isValid]
}
// Assert
results.each { rValue, isRValid ->
assert !isRValid
}
} }
@Test @Test
@ -288,28 +285,19 @@ class ScryptSecureHasherTest {
final Integer n = 16385 final Integer n = 16385
final int r = 8 final int r = 8
// Act // Act and Assert
boolean valid = ScryptSecureHasher.isNValid(n, r) assertTrue(ScryptSecureHasher.isNValid(n, r))
// Assert
assert valid
} }
@Test @Test
void testShouldFailNBoundary() throws Exception { void testShouldFailNBoundary() throws Exception {
// Arrange // Arrange
Map costParameters = [(-8): 8, 0: 32] Map<Integer, Integer> costParameters = [(-8): 8, 0: 32]
// Act //Act and Assert
def results = costParameters.collect { n, p -> costParameters.entrySet().forEach(entry -> {
def isValid = ScryptSecureHasher.isNValid(n, p) assertFalse(ScryptSecureHasher.isNValid(entry.getKey(), entry.getValue()))
[n, isValid] })
}
// Assert
results.each { n, isNValid ->
assert !isNValid
}
} }
@Test @Test
@ -318,19 +306,12 @@ class ScryptSecureHasherTest {
final List<Integer> ps = [1, 8, 1024] final List<Integer> ps = [1, 8, 1024]
final List<Integer> rs = [8, 1024, 4096] final List<Integer> rs = [8, 1024, 4096]
// Act // Act and Assert
def pResults = ps.collectEntries { int p -> ps.forEach(p -> {
def rResults = rs.collectEntries { int r -> rs.forEach(r -> {
boolean valid = ScryptSecureHasher.isPValid(p, r) assertTrue(ScryptSecureHasher.isPValid(p, r))
[r, valid] })
} })
[p, rResults]
}
// Assert
pResults.each { p, rResult ->
assert rResult.every { r, isValid -> isValid }
}
} }
@Test @Test
@ -339,19 +320,12 @@ class ScryptSecureHasherTest {
final List<Integer> ps = [4096 * 64, 1024 * 1024] final List<Integer> ps = [4096 * 64, 1024 * 1024]
final List<Integer> rs = [4096, 1024 * 1024] final List<Integer> rs = [4096, 1024 * 1024]
// Act // Act and Assert
def pResults = ps.collectEntries { int p -> ps.forEach(p -> {
def rResults = rs.collectEntries { int r -> rs.forEach(r -> {
boolean valid = ScryptSecureHasher.isPValid(p, r) assertFalse(ScryptSecureHasher.isPValid(p, r))
[r, valid] })
} })
[p, rResults]
}
// Assert
pResults.each { p, rResult ->
assert rResult.every { r, isValid -> !isValid }
}
} }
@Test @Test
@ -363,7 +337,7 @@ class ScryptSecureHasherTest {
boolean valid = ScryptSecureHasher.isDKLengthValid(dkLength) boolean valid = ScryptSecureHasher.isDKLengthValid(dkLength)
// Assert // Assert
assert valid assertTrue(valid)
} }
@Test @Test
@ -371,16 +345,10 @@ class ScryptSecureHasherTest {
// Arrange // Arrange
def dKLengths = [-8, 0, 2147483647] def dKLengths = [-8, 0, 2147483647]
// Act // Act and Assert
def results = dKLengths.collect { dKLength -> dKLengths.forEach( dKLength -> {
def isValid = ScryptSecureHasher.isDKLengthValid(dKLength) assertFalse(ScryptSecureHasher.isDKLengthValid(dKLength))
[dKLength, isValid] })
}
// Assert
results.each { dKLength, isDKLengthValid ->
assert !isDKLengthValid
}
} }
@Test @Test
@ -388,16 +356,11 @@ class ScryptSecureHasherTest {
// Arrange // Arrange
def saltLengths = [0, 64] def saltLengths = [0, 64]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher()
def isValid = new ScryptSecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> {
[saltLength, isValid] assertTrue(scryptSecureHasher.isSaltLengthValid(saltLength))
} })
// Assert
results.each { saltLength, isSaltLengthValid ->
assert { it == isSaltLengthValid }
}
} }
@Test @Test
@ -405,16 +368,10 @@ class ScryptSecureHasherTest {
// Arrange // Arrange
def saltLengths = [-8, 1, 2147483647] def saltLengths = [-8, 1, 2147483647]
// Act // Act and Assert
def results = saltLengths.collect { saltLength -> ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher()
def isValid = new ScryptSecureHasher().isSaltLengthValid(saltLength) saltLengths.forEach(saltLength -> {
[saltLength, isValid] assertFalse(scryptSecureHasher.isSaltLengthValid(saltLength))
} })
// Assert
results.each { saltLength, isSaltLengthValid ->
assert !isSaltLengthValid
}
} }
} }

View File

@ -20,7 +20,6 @@ import org.apache.commons.codec.binary.Hex
import org.apache.nifi.security.util.crypto.scrypt.Scrypt import org.apache.nifi.security.util.crypto.scrypt.Scrypt
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.condition.EnabledIfSystemProperty import org.junit.jupiter.api.condition.EnabledIfSystemProperty
import org.slf4j.Logger import org.slf4j.Logger
@ -29,7 +28,11 @@ import org.slf4j.LoggerFactory
import java.security.SecureRandom import java.security.SecureRandom
import java.security.Security import java.security.Security
import static groovy.test.GroovyAssert.shouldFail import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertFalse
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
import static org.junit.jupiter.api.Assumptions.assumeTrue import static org.junit.jupiter.api.Assumptions.assumeTrue
class ScryptGroovyTest { class ScryptGroovyTest {
@ -71,8 +74,8 @@ class ScryptGroovyTest {
} }
// Assert // Assert
assert allKeys.size() == RUNS assertEquals(RUNS, allKeys.size())
assert allKeys.every { it == allKeys.first() } allKeys.forEach(key -> assertArrayEquals(allKeys.first(), key))
} }
/** /**
@ -122,7 +125,7 @@ class ScryptGroovyTest {
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}") logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert // Assert
assert calculatedHash == params.hash assertArrayEquals(params.hash, calculatedHash)
} }
} }
@ -163,7 +166,7 @@ class ScryptGroovyTest {
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}") logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert // Assert
assert calculatedHash == HASH assertArrayEquals(HASH, calculatedHash)
} }
@EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true") @EnabledIfSystemProperty(named = "nifi.test.unstable", matches = "true")
@ -203,7 +206,7 @@ class ScryptGroovyTest {
logger.info("Generated ${Hex.encodeHexString(calculatedHash)}") logger.info("Generated ${Hex.encodeHexString(calculatedHash)}")
// Assert // Assert
assert calculatedHash == Hex.decodeHex(EXPECTED_KEY_HEX as char[]) assertArrayEquals(Hex.decodeHex(EXPECTED_KEY_HEX as char[]), calculatedHash)
} }
@Test @Test
@ -222,8 +225,8 @@ class ScryptGroovyTest {
} }
// Assert // Assert
assert allHashes.size() == RUNS assertEquals(RUNS, allHashes.size())
assert allHashes.every { it == allHashes.first() } allHashes.forEach(hash -> assertEquals(allHashes.first(), hash))
} }
@Test @Test
@ -231,14 +234,14 @@ class ScryptGroovyTest {
// Arrange // Arrange
// The generated salt should be byte[16], encoded as 22 Base64 chars // The generated salt should be byte[16], encoded as 22 Base64 chars
final def EXPECTED_SALT_PATTERN = /\$.+\$[0-9a-zA-Z\/\+]{22}\$.+/ final EXPECTED_SALT_PATTERN = /\$.+\$[0-9a-zA-Z\/\+]{22}\$.+/
// Act // Act
String calculatedHash = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN) String calculatedHash = Scrypt.scrypt(PASSWORD, N, R, P, DK_LEN)
logger.info("Generated ${calculatedHash}") logger.info("Generated ${calculatedHash}")
// Assert // Assert
assert calculatedHash =~ EXPECTED_SALT_PATTERN assertTrue((calculatedHash =~ EXPECTED_SALT_PATTERN).matches())
} }
@Test @Test
@ -254,12 +257,11 @@ class ScryptGroovyTest {
INVALID_NS.each { int invalidN -> INVALID_NS.each { int invalidN ->
logger.info("Using N: ${invalidN}") logger.info("Using N: ${invalidN}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, invalidN, R, P, DK_LEN) () -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, invalidN, R, P, DK_LEN))
}
// Assert // Assert
assert msg =~ "N must be a power of 2 greater than 1|Parameter N is too large" assertTrue((iae.getMessage() =~ "N must be a power of 2 greater than 1|Parameter N is too large").matches())
} }
} }
@ -278,13 +280,11 @@ class ScryptGroovyTest {
INVALID_RS.each { int invalidR -> INVALID_RS.each { int invalidR ->
logger.info("Using r: ${invalidR}") logger.info("Using r: ${invalidR}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
byte[] hash = Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, invalidR, largeP, DK_LEN) () -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, invalidR, largeP, DK_LEN))
logger.info("Generated hash: ${Hex.encodeHexString(hash)}")
}
// Assert // Assert
assert msg =~ "Parameter r must be 1 or greater|Parameter r is too large" assertTrue((iae.getMessage() =~ "Parameter r must be 1 or greater|Parameter r is too large").matches())
} }
} }
@ -300,13 +300,11 @@ class ScryptGroovyTest {
INVALID_PS.each { int invalidP -> INVALID_PS.each { int invalidP ->
logger.info("Using p: ${invalidP}") logger.info("Using p: ${invalidP}")
def msg = shouldFail(IllegalArgumentException) { IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
byte[] hash = Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, R, invalidP, DK_LEN) () -> Scrypt.deriveScryptKey(PASSWORD.bytes, SALT_BYTES, N, R, invalidP, DK_LEN))
logger.info("Generated hash: ${Hex.encodeHexString(hash)}")
}
// Assert // Assert
assert msg =~ "Parameter p must be 1 or greater|Parameter p is too large" assertTrue((iae.getMessage() =~ "Parameter p must be 1 or greater|Parameter p is too large").matches())
} }
} }
@ -322,7 +320,7 @@ class ScryptGroovyTest {
logger.info("Check matches: ${matches}") logger.info("Check matches: ${matches}")
// Assert // Assert
assert matches assertTrue(matches)
} }
@Test @Test
@ -337,7 +335,7 @@ class ScryptGroovyTest {
logger.info("Check matches: ${matches}") logger.info("Check matches: ${matches}")
// Assert // Assert
assert !matches assertFalse(matches)
} }
@Test @Test
@ -351,14 +349,12 @@ class ScryptGroovyTest {
// Act // Act
INVALID_PASSWORDS.each { String invalidPassword -> INVALID_PASSWORDS.each { String invalidPassword ->
logger.info("Using password: ${invalidPassword}") logger.info("Using password: ${invalidPassword}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
def msg = shouldFail(IllegalArgumentException) { () -> Scrypt.check(invalidPassword, HASH))
boolean matches = Scrypt.check(invalidPassword, HASH) logger.expected(iae.getMessage())
}
logger.expected(msg)
// Assert // Assert
assert msg =~ "Password cannot be empty" assertTrue(iae.getMessage().contains("Password cannot be empty"))
} }
} }
@ -373,14 +369,12 @@ class ScryptGroovyTest {
// Act // Act
INVALID_HASHES.each { String invalidHash -> INVALID_HASHES.each { String invalidHash ->
logger.info("Using hash: ${invalidHash}") logger.info("Using hash: ${invalidHash}")
IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
def msg = shouldFail(IllegalArgumentException) { () -> Scrypt.check(PASSWORD, invalidHash))
boolean matches = Scrypt.check(PASSWORD, invalidHash) logger.expected(iae.getMessage())
}
logger.expected(msg)
// Assert // Assert
assert msg =~ "Hash cannot be empty|Hash is not properly formatted" assertTrue((iae.getMessage() =~ "Hash cannot be empty|Hash is not properly formatted").matches())
} }
} }
@ -401,7 +395,8 @@ class ScryptGroovyTest {
"\$s0\$F0801\$AAAAAAAAAAA\$A", "\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$A", "\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$A",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP", "\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP", "\$s0\$40801\$ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP\$" +
"ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKLMNOP",
"\$s0\$F0801\$AAAAAAAAAAA\$A", "\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A", "\$s0\$F0801\$AAAAAAAAAAA\$A",
"\$s0\$F0801\$AAAAAAAAAAA\$A", "\$s0\$F0801\$AAAAAAAAAAA\$A",
@ -417,7 +412,7 @@ class ScryptGroovyTest {
logger.info("Hash is valid: ${isValidHash}") logger.info("Hash is valid: ${isValidHash}")
// Assert // Assert
assert isValidHash assertTrue(isValidHash)
} }
} }
@ -436,7 +431,7 @@ class ScryptGroovyTest {
logger.info("Hash is valid: ${isValidHash}") logger.info("Hash is valid: ${isValidHash}")
// Assert // Assert
assert !isValidHash assertFalse(isValidHash)
} }
} }
} }

View File

@ -38,9 +38,10 @@ import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.UUID; import java.util.UUID;
import static org.junit.Assert.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class KeyStoreUtilsTest { public class KeyStoreUtilsTest {
private static final String SIGNING_ALGORITHM = "SHA256withRSA"; private static final String SIGNING_ALGORITHM = "SHA256withRSA";
@ -89,18 +90,18 @@ public class KeyStoreUtilsTest {
); );
final TlsConfiguration configuration = KeyStoreUtils.createTlsConfigAndNewKeystoreTruststore(requested, 1, new String[] { HOSTNAME }); final TlsConfiguration configuration = KeyStoreUtils.createTlsConfigAndNewKeystoreTruststore(requested, 1, new String[] { HOSTNAME });
final File keystoreFile = new File(configuration.getKeystorePath()); final File keystoreFile = new File(configuration.getKeystorePath());
assertTrue("Keystore File not found", keystoreFile.exists()); assertTrue(keystoreFile.exists(), "Keystore File not found");
keystoreFile.deleteOnExit(); keystoreFile.deleteOnExit();
final File truststoreFile = new File(configuration.getTruststorePath()); final File truststoreFile = new File(configuration.getTruststorePath());
assertTrue("Truststore File not found", truststoreFile.exists()); assertTrue(truststoreFile.exists(),"Truststore File not found");
truststoreFile.deleteOnExit(); truststoreFile.deleteOnExit();
assertEquals("Keystore Type not matched", KeystoreType.PKCS12, configuration.getKeystoreType()); assertEquals(KeystoreType.PKCS12, configuration.getKeystoreType(), "Keystore Type not matched");
assertEquals("Truststore Type not matched", KeystoreType.PKCS12, configuration.getTruststoreType()); assertEquals(KeystoreType.PKCS12, configuration.getTruststoreType(), "Truststore Type not matched");
assertTrue("Keystore not valid", KeyStoreUtils.isStoreValid(keystoreFile.toURI().toURL(), configuration.getKeystoreType(), configuration.getKeystorePassword().toCharArray())); assertTrue(KeyStoreUtils.isStoreValid(keystoreFile.toURI().toURL(), configuration.getKeystoreType(), configuration.getKeystorePassword().toCharArray()), "Keystore not valid");
assertTrue("Truststore not valid", KeyStoreUtils.isStoreValid(truststoreFile.toURI().toURL(), configuration.getTruststoreType(), configuration.getTruststorePassword().toCharArray())); assertTrue(KeyStoreUtils.isStoreValid(truststoreFile.toURI().toURL(), configuration.getTruststoreType(), configuration.getTruststorePassword().toCharArray()), "Truststore not valid");
} }
@Test @Test
@ -141,7 +142,7 @@ public class KeyStoreUtilsTest {
sourceKeyStore.setCertificateEntry(ALIAS, certificate); sourceKeyStore.setCertificateEntry(ALIAS, certificate);
final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore); final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore);
assertEquals(String.format("[%s] Certificate not matched", sourceKeyStore.getType()), certificate, copiedKeyStore.getCertificate(ALIAS)); assertEquals(certificate, copiedKeyStore.getCertificate(ALIAS), String.format("[%s] Certificate not matched", sourceKeyStore.getType()));
} }
private void assertKeyEntryStoredLoaded(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException { private void assertKeyEntryStoredLoaded(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException {
@ -151,13 +152,13 @@ public class KeyStoreUtilsTest {
final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore); final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore);
final KeyStore.Entry entry = copiedKeyStore.getEntry(ALIAS, new KeyStore.PasswordProtection(KEY_PASSWORD)); final KeyStore.Entry entry = copiedKeyStore.getEntry(ALIAS, new KeyStore.PasswordProtection(KEY_PASSWORD));
assertTrue(String.format("[%s] Private Key entry not found", sourceKeyStore.getType()), entry instanceof KeyStore.PrivateKeyEntry); assertInstanceOf(KeyStore.PrivateKeyEntry.class, entry, String.format("[%s] Private Key entry not found", sourceKeyStore.getType()));
final KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) entry; final KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) entry;
final Certificate[] entryCertificateChain = privateKeyEntry.getCertificateChain(); final Certificate[] entryCertificateChain = privateKeyEntry.getCertificateChain();
assertArrayEquals(String.format("[%s] Certificate Chain not matched", sourceKeyStore.getType()), certificateChain, entryCertificateChain); assertArrayEquals(certificateChain, entryCertificateChain, String.format("[%s] Certificate Chain not matched", sourceKeyStore.getType()));
assertEquals(String.format("[%s] Private Key not matched", sourceKeyStore.getType()), keyPair.getPrivate(), privateKeyEntry.getPrivateKey()); assertEquals(keyPair.getPrivate(), privateKeyEntry.getPrivateKey(), String.format("[%s] Private Key not matched", sourceKeyStore.getType()));
assertEquals(String.format("[%s] Public Key not matched", sourceKeyStore.getType()), keyPair.getPublic(), entryCertificateChain[0].getPublicKey()); assertEquals(keyPair.getPublic(), entryCertificateChain[0].getPublicKey(), String.format("[%s] Public Key not matched", sourceKeyStore.getType()));
} }
private void assertSecretKeyStoredLoaded(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException { private void assertSecretKeyStoredLoaded(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException {
@ -167,7 +168,7 @@ public class KeyStoreUtilsTest {
final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore); final KeyStore copiedKeyStore = copyKeyStore(sourceKeyStore, destinationKeyStore);
final KeyStore.Entry entry = copiedKeyStore.getEntry(ALIAS, protection); final KeyStore.Entry entry = copiedKeyStore.getEntry(ALIAS, protection);
assertTrue(String.format("[%s] Secret Key entry not found", sourceKeyStore.getType()), entry instanceof KeyStore.SecretKeyEntry); assertInstanceOf(KeyStore.SecretKeyEntry.class, entry, String.format("[%s] Secret Key entry not found", sourceKeyStore.getType()));
} }
private KeyStore copyKeyStore(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException { private KeyStore copyKeyStore(final KeyStore sourceKeyStore, final KeyStore destinationKeyStore) throws GeneralSecurityException, IOException {

View File

@ -18,7 +18,7 @@ package org.apache.nifi.security.util.crypto;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class SecureHasherFactoryTest { public class SecureHasherFactoryTest {

View File

@ -33,7 +33,12 @@ import org.slf4j.LoggerFactory
import java.security.Security import java.security.Security
import java.util.concurrent.ArrayBlockingQueue import java.util.concurrent.ArrayBlockingQueue
class PeerSelectorTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertNotNull
import static org.junit.jupiter.api.Assertions.assertNull
import static org.junit.jupiter.api.Assertions.assertTrue
class PeerSelectorTest {
private static final Logger logger = LoggerFactory.getLogger(PeerSelectorTest.class) private static final Logger logger = LoggerFactory.getLogger(PeerSelectorTest.class)
private static final BOOTSTRAP_PEER_DESCRIPTION = new PeerDescription("localhost", -1, false) private static final BOOTSTRAP_PEER_DESCRIPTION = new PeerDescription("localhost", -1, false)
@ -135,7 +140,7 @@ class PeerSelectorTest extends GroovyTestCase {
final Map<String, Double> EXPECTED_PERCENTS, final Map<String, Double> EXPECTED_PERCENTS,
final int NUM_TIMES = resultsFrequency.values().sum() as int, final int NUM_TIMES = resultsFrequency.values().sum() as int,
final double TOLERANCE = 0.05) { final double TOLERANCE = 0.05) {
assert resultsFrequency.keySet() == EXPECTED_PERCENTS.keySet() assertEquals(EXPECTED_PERCENTS.keySet(), resultsFrequency.keySet())
logger.info(" Actual results: ${resultsFrequency.sort()}") logger.info(" Actual results: ${resultsFrequency.sort()}")
logger.info("Expected results: ${EXPECTED_PERCENTS.sort().collect { k, v -> "${k}: ${v}%" }}") logger.info("Expected results: ${EXPECTED_PERCENTS.sort().collect { k, v -> "${k}: ${v}%" }}")
@ -154,7 +159,7 @@ class PeerSelectorTest extends GroovyTestCase {
def count = resultsFrequency[k] def count = resultsFrequency[k]
def difference = Math.abs(expectedCount - count) / NUM_TIMES def difference = Math.abs(expectedCount - count) / NUM_TIMES
logger.debug("Checking that ${count} is within ±${TOLERANCE * 100}% of ${expectedCount} (${lowerBound}, ${upperBound}) | ${(difference * 100).round(2)}%") logger.debug("Checking that ${count} is within ±${TOLERANCE * 100}% of ${expectedCount} (${lowerBound}, ${upperBound}) | ${(difference * 100).round(2)}%")
assert count >= lowerBound && count <= upperBound assertTrue(count >= lowerBound && count <= upperBound)
} }
} }
@ -169,7 +174,7 @@ class PeerSelectorTest extends GroovyTestCase {
int consecutiveElements = recentPeerSelectionQueue.getMaxConsecutiveElements() int consecutiveElements = recentPeerSelectionQueue.getMaxConsecutiveElements()
// String mcce = recentPeerSelectionQueue.getMostCommonConsecutiveElement() // String mcce = recentPeerSelectionQueue.getMostCommonConsecutiveElement()
// logger.debug("Most consecutive elements in recentPeerSelectionQueue: ${consecutiveElements} - ${mcce} | ${recentPeerSelectionQueue}") // logger.debug("Most consecutive elements in recentPeerSelectionQueue: ${consecutiveElements} - ${mcce} | ${recentPeerSelectionQueue}")
assert consecutiveElements <= recentPeerSelectionQueue.totalSize - 1 assertTrue(consecutiveElements <= recentPeerSelectionQueue.totalSize - 1)
} }
private static double calculateMean(Map resultsFrequency) { private static double calculateMean(Map resultsFrequency) {
@ -179,7 +184,9 @@ class PeerSelectorTest extends GroovyTestCase {
return meanElements.sum() / meanElements.size() return meanElements.sum() / meanElements.size()
} }
private static PeerStatusProvider mockPeerStatusProvider(PeerDescription bootstrapPeerDescription = BOOTSTRAP_PEER_DESCRIPTION, String remoteInstanceUris = DEFAULT_REMOTE_INSTANCE_URIS, Map<PeerDescription, Set<PeerStatus>> peersMap = DEFAULT_PEER_NODES) { private static PeerStatusProvider mockPeerStatusProvider(PeerDescription bootstrapPeerDescription = BOOTSTRAP_PEER_DESCRIPTION,
String remoteInstanceUris = DEFAULT_REMOTE_INSTANCE_URIS,
Map<PeerDescription, Set<PeerStatus>> peersMap = DEFAULT_PEER_NODES) {
[getTransportProtocol : { -> [getTransportProtocol : { ->
SiteToSiteTransportProtocol.HTTP SiteToSiteTransportProtocol.HTTP
}, },
@ -230,8 +237,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}") logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert // Assert
assert peersToQuery.size() == 1 assertEquals(1, peersToQuery.size())
assert peersToQuery.first() == BOOTSTRAP_PEER_DESCRIPTION assertEquals(BOOTSTRAP_PEER_DESCRIPTION, peersToQuery.first())
} }
@Test @Test
@ -249,9 +256,9 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}") logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert // Assert
assert peersToQuery.size() == restoredPeerStatuses.size() + 1 assertEquals(restoredPeerStatuses.size() + 1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS) assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
} }
/** /**
@ -275,11 +282,11 @@ class PeerSelectorTest extends GroovyTestCase {
} }
// Assert // Assert
assert peersToQuery.size() == DEFAULT_PEER_STATUSES.size() + 1 assertEquals(DEFAULT_PEER_STATUSES.size() + 1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS) assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
assert repeatedPeersToQuery.every { it == peersToQuery } repeatedPeersToQuery.forEach(query -> assertEquals(peersToQuery, query))
} }
@Test @Test
@ -292,8 +299,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${remotePeerStatuses.size()} peer statuses: ${remotePeerStatuses}") logger.info("Retrieved ${remotePeerStatuses.size()} peer statuses: ${remotePeerStatuses}")
// Assert // Assert
assert remotePeerStatuses.size() == DEFAULT_PEER_STATUSES.size() assertEquals(DEFAULT_PEER_STATUSES.size(), remotePeerStatuses.size())
assert remotePeerStatuses.containsAll(DEFAULT_PEER_STATUSES) assertTrue(remotePeerStatuses.containsAll(DEFAULT_PEER_STATUSES))
} }
/** /**
@ -339,10 +346,10 @@ class PeerSelectorTest extends GroovyTestCase {
} }
// Assert that the send percentage is always between 0% and 80% // Assert that the send percentage is always between 0% and 80%
assert r.every { k, v -> v.send >= 0 && v.send <= 80 } r.every { k, v -> assertTrue(v.send >= 0 && v.send <= 80) }
// Assert that the receive percentage is always between 0% and 100% // Assert that the receive percentage is always between 0% and 100%
assert r.every { k, v -> v.receive >= 0 && v.receive <= 100 } r.every { k, v -> assertTrue(v.receive >= 0 && v.receive <= 100) }
} }
} }
} }
@ -365,8 +372,8 @@ class PeerSelectorTest extends GroovyTestCase {
double receiveWeight = PeerSelector.calculateNormalizedWeight(TransferDirection.RECEIVE, totalFlowfileCount, flowfileCount, NODE_COUNT) double receiveWeight = PeerSelector.calculateNormalizedWeight(TransferDirection.RECEIVE, totalFlowfileCount, flowfileCount, NODE_COUNT)
// Assert // Assert
assert sendWeight == 100 assertEquals(100, sendWeight)
assert receiveWeight == 100 assertEquals(100, receiveWeight)
} }
} }
} }
@ -392,7 +399,7 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Weighted peer map: ${weightedPeerMap}") logger.info("Weighted peer map: ${weightedPeerMap}")
// Assert // Assert
assert new ArrayList<>(weightedPeerMap.keySet()) == new ArrayList(clusterMap.keySet()) assertEquals(clusterMap.keySet(), weightedPeerMap.keySet())
} }
@Test @Test
@ -416,7 +423,7 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Weighted peer map: ${weightedPeerMap}") logger.info("Weighted peer map: ${weightedPeerMap}")
// Assert // Assert
assert new ArrayList<>(weightedPeerMap.keySet()) == new ArrayList(clusterMap.keySet()) assertEquals(clusterMap.keySet(), weightedPeerMap.keySet())
} }
/** /**
@ -450,11 +457,11 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Destination map: ${destinationMap}") logger.info("Destination map: ${destinationMap}")
// Assert // Assert
assert destinationMap.keySet() == peerStatuses assertEquals(peerStatuses, destinationMap.keySet())
// For uneven splits, the resulting percentage should be within +/- 1% // For uneven splits, the resulting percentage should be within +/- 1%
def totalPercentage = destinationMap.values().sum() def totalPercentage = destinationMap.values().sum()
assert totalPercentage >= 99 && totalPercentage <= 100 assertTrue(totalPercentage >= 99 && totalPercentage <= 100)
} }
} }
} }
@ -603,7 +610,7 @@ class PeerSelectorTest extends GroovyTestCase {
// Spot check consecutive selection // Spot check consecutive selection
if (i % 10 == 0) { if (i % 10 == 0) {
int consecutiveElements = lastN.getMaxConsecutiveElements() int consecutiveElements = lastN.getMaxConsecutiveElements()
assert consecutiveElements == lastN.size() assertEquals(lastN.size(), consecutiveElements)
} }
} }
@ -784,7 +791,8 @@ class PeerSelectorTest extends GroovyTestCase {
cacheFile.deleteOnExit() cacheFile.deleteOnExit()
// Construct the cache contents and write to disk // Construct the cache contents and write to disk
final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" + "${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps -> final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" +
"${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps ->
[ps.peerDescription.hostname, ps.peerDescription.port, ps.peerDescription.isSecure(), ps.isQueryForPeers()].join(":") [ps.peerDescription.hostname, ps.peerDescription.port, ps.peerDescription.isSecure(), ps.isQueryForPeers()].join(":")
}.join("\n") }.join("\n")
cacheFile.text = CACHE_CONTENTS cacheFile.text = CACHE_CONTENTS
@ -801,9 +809,9 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}") logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert // Assert
assert peersToQuery.size() == nodes.size() + 1 assertEquals(nodes.size() + 1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS) assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
} }
/** /**
@ -824,7 +832,8 @@ class PeerSelectorTest extends GroovyTestCase {
cacheFile.deleteOnExit() cacheFile.deleteOnExit()
// Construct the cache contents and write to disk // Construct the cache contents and write to disk
final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" + "${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps -> final String CACHE_CONTENTS = "${mockPSP.getTransportProtocol()}\n" +
"${AbstractPeerPersistence.REMOTE_INSTANCE_URIS_PREFIX}${mockPSP.getRemoteInstanceUris()}\n" + peerStatuses.collect { PeerStatus ps ->
[ps.peerDescription.hostname, ps.peerDescription.port, ps.peerDescription.isSecure(), ps.isQueryForPeers()].join(":") [ps.peerDescription.hostname, ps.peerDescription.port, ps.peerDescription.isSecure(), ps.isQueryForPeers()].join(":")
}.join("\n") }.join("\n")
cacheFile.text = CACHE_CONTENTS cacheFile.text = CACHE_CONTENTS
@ -842,16 +851,16 @@ class PeerSelectorTest extends GroovyTestCase {
// Assert // Assert
// The loaded cache should be marked as expired and not used // The loaded cache should be marked as expired and not used
assert ps.isCacheExpired(ps.peerStatusCache) assertTrue(ps.isCacheExpired(ps.peerStatusCache))
// This internal method does not refresh or check expiration // This internal method does not refresh or check expiration
def peersToQuery = ps.getPeersToQuery() def peersToQuery = ps.getPeersToQuery()
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}") logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// The cache has (expired) peer statuses present // The cache has (expired) peer statuses present
assert peersToQuery.size() == nodes.size() + 1 assertEquals(nodes.size() + 1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
assert peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS) assertTrue(peersToQuery.containsAll(DEFAULT_PEER_DESCRIPTIONS))
// Trigger the cache expiration detection // Trigger the cache expiration detection
ps.refresh() ps.refresh()
@ -860,8 +869,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("After cache expiration, retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}") logger.info("After cache expiration, retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// The cache only contains the bootstrap node // The cache only contains the bootstrap node
assert peersToQuery.size() == 1 assertEquals(1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
} }
Throwable generateException(String message, int nestedLevel = 0) { Throwable generateException(String message, int nestedLevel = 0) {
@ -895,8 +904,8 @@ class PeerSelectorTest extends GroovyTestCase {
def peersToQuery = ps.getPeersToQuery() def peersToQuery = ps.getPeersToQuery()
// Assert // Assert
assert peersToQuery.size() == 1 assertEquals(1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
} }
/** /**
@ -931,8 +940,8 @@ class PeerSelectorTest extends GroovyTestCase {
logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}") logger.info("Retrieved ${peersToQuery.size()} peers to query: ${peersToQuery}")
// Assert // Assert
assert peersToQuery.size() == 1 assertEquals(1, peersToQuery.size())
assert peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION) assertTrue(peersToQuery.contains(BOOTSTRAP_PEER_DESCRIPTION))
} }
/** /**
@ -998,7 +1007,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh() ps.refresh()
PeerStatus peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE) PeerStatus peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}") logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert peerStatus assertNotNull(peerStatus)
// Force the selector to refresh the cache // Force the selector to refresh the cache
currentAttempt++ currentAttempt++
@ -1009,7 +1018,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh() ps.refresh()
peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE) peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}") logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert peerStatus == node2Status assertEquals(node2Status, peerStatus)
// Force the selector to refresh the cache // Force the selector to refresh the cache
currentAttempt++ currentAttempt++
@ -1020,7 +1029,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh() ps.refresh()
peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE) peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}") logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert !peerStatus assertNull(peerStatus)
// Force the selector to refresh the cache // Force the selector to refresh the cache
currentAttempt = 5 currentAttempt = 5
@ -1030,7 +1039,7 @@ class PeerSelectorTest extends GroovyTestCase {
ps.refresh() ps.refresh()
peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE) peerStatus = ps.getNextPeerStatus(TransferDirection.RECEIVE)
logger.info("Attempt ${currentAttempt} - ${peerStatus}") logger.info("Attempt ${currentAttempt} - ${peerStatus}")
assert peerStatus == bootstrapStatus assertEquals(bootstrapStatus, peerStatus)
} }
// PeerQueue definition and tests // PeerQueue definition and tests
@ -1052,7 +1061,7 @@ class PeerSelectorTest extends GroovyTestCase {
peerQueue.append(nodes.first()) peerQueue.append(nodes.first())
// Assert // Assert
assert peerQueue.getMaxConsecutiveElements() == peerQueue.size() assertEquals(peerQueue.size(), peerQueue.getMaxConsecutiveElements())
} }
// Never repeating node // Never repeating node
@ -1061,7 +1070,7 @@ class PeerSelectorTest extends GroovyTestCase {
peerQueue.append(nodes.get(i % peerStatuses.size())) peerQueue.append(nodes.get(i % peerStatuses.size()))
// Assert // Assert
assert peerQueue.getMaxConsecutiveElements() == 1 assertEquals(1, peerQueue.getMaxConsecutiveElements())
} }
// Repeat up to nodes.size() times but no more // Repeat up to nodes.size() times but no more
@ -1072,7 +1081,7 @@ class PeerSelectorTest extends GroovyTestCase {
// Assert // Assert
// logger.debug("Most consecutive elements in queue: ${peerQueue.getMaxConsecutiveElements()} | ${peerQueue}") // logger.debug("Most consecutive elements in queue: ${peerQueue.getMaxConsecutiveElements()} | ${peerQueue}")
assert peerQueue.getMaxConsecutiveElements() <= peerStatuses.size() assertTrue(peerQueue.getMaxConsecutiveElements() <= peerStatuses.size())
} }
} }

View File

@ -29,7 +29,10 @@ import org.slf4j.LoggerFactory
import javax.net.ssl.SSLServerSocket import javax.net.ssl.SSLServerSocket
import java.security.Security import java.security.Security
class SocketUtilsTest extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertArrayEquals
import static org.junit.jupiter.api.Assertions.assertFalse
class SocketUtilsTest {
private static final Logger logger = LoggerFactory.getLogger(SocketUtilsTest.class) private static final Logger logger = LoggerFactory.getLogger(SocketUtilsTest.class)
private static final String KEYSTORE_PATH = "src/test/resources/TlsConfigurationKeystore.jks" private static final String KEYSTORE_PATH = "src/test/resources/TlsConfigurationKeystore.jks"
@ -56,7 +59,9 @@ class SocketUtilsTest extends GroovyTestCase {
private NiFiProperties mockNiFiProperties = NiFiProperties.createBasicNiFiProperties(null, DEFAULT_PROPS) private NiFiProperties mockNiFiProperties = NiFiProperties.createBasicNiFiProperties(null, DEFAULT_PROPS)
// A static TlsConfiguration referencing the test resource keystore and truststore // A static TlsConfiguration referencing the test resource keystore and truststore
// private static final TlsConfiguration TLS_CONFIGURATION = new StandardTlsConfiguration(KEYSTORE_PATH, KEYSTORE_PASSWORD, KEY_PASSWORD, KEYSTORE_TYPE, TRUSTSTORE_PATH, TRUSTSTORE_PASSWORD, TRUSTSTORE_TYPE, PROTOCOL) // private static final TlsConfiguration TLS_CONFIGURATION =
// new StandardTlsConfiguration(KEYSTORE_PATH, KEYSTORE_PASSWORD, KEY_PASSWORD, KEYSTORE_TYPE, TRUSTSTORE_PATH,
// TRUSTSTORE_PASSWORD, TRUSTSTORE_TYPE, PROTOCOL)
// private static final SSLContext sslContext = SslContextFactory.createSslContext(TLS_CONFIGURATION, ClientAuth.NONE) // private static final SSLContext sslContext = SslContextFactory.createSslContext(TLS_CONFIGURATION, ClientAuth.NONE)
@BeforeAll @BeforeAll
@ -81,9 +86,9 @@ class SocketUtilsTest extends GroovyTestCase {
// Assert // Assert
String[] enabledProtocols = sslServerSocket.getEnabledProtocols() String[] enabledProtocols = sslServerSocket.getEnabledProtocols()
logger.info("Enabled protocols: ${enabledProtocols}") logger.info("Enabled protocols: ${enabledProtocols}")
assert enabledProtocols == TlsConfiguration.getCurrentSupportedTlsProtocolVersions() assertArrayEquals(TlsConfiguration.getCurrentSupportedTlsProtocolVersions(), enabledProtocols)
assert !enabledProtocols.contains("TLSv1") assertFalse(enabledProtocols.contains("TLSv1"))
assert !enabledProtocols.contains("TLSv1.1") assertFalse(enabledProtocols.contains("TLSv1.1"))
} }
@Test @Test
@ -99,8 +104,8 @@ class SocketUtilsTest extends GroovyTestCase {
// Assert // Assert
String[] enabledProtocols = sslServerSocket.getEnabledProtocols() String[] enabledProtocols = sslServerSocket.getEnabledProtocols()
logger.info("Enabled protocols: ${enabledProtocols}") logger.info("Enabled protocols: ${enabledProtocols}")
assert enabledProtocols == TlsConfiguration.getCurrentSupportedTlsProtocolVersions() assertArrayEquals(TlsConfiguration.getCurrentSupportedTlsProtocolVersions(), enabledProtocols)
assert !enabledProtocols.contains("TLSv1") assertFalse(enabledProtocols.contains("TLSv1"))
assert !enabledProtocols.contains("TLSv1.1") assertFalse(enabledProtocols.contains("TLSv1.1"))
} }
} }

View File

@ -22,8 +22,13 @@ import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.stream.IntStream
class TestFormatUtilsGroovy extends GroovyTestCase { import static org.junit.jupiter.api.Assertions.assertEquals
import static org.junit.jupiter.api.Assertions.assertThrows
import static org.junit.jupiter.api.Assertions.assertTrue
class TestFormatUtilsGroovy {
private static final Logger logger = LoggerFactory.getLogger(TestFormatUtilsGroovy.class) private static final Logger logger = LoggerFactory.getLogger(TestFormatUtilsGroovy.class)
@BeforeAll @BeforeAll
@ -49,7 +54,7 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(days) logger.converted(days)
// Assert // Assert
assert days.every { it == EXPECTED_DAYS } days.forEach(it -> assertEquals(EXPECTED_DAYS, it))
} }
@ -59,14 +64,12 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
final List WEEKS = ["-1 week", "-1 wk", "-1 w", "-1 weeks", "- 1 week"] final List WEEKS = ["-1 week", "-1 wk", "-1 w", "-1 weeks", "- 1 week"]
// Act // Act
List msgs = WEEKS.collect { String week -> WEEKS.stream().forEach(week -> {
shouldFail(IllegalArgumentException) { IllegalArgumentException iae =
FormatUtils.getTimeDuration(week, TimeUnit.DAYS) assertThrows(IllegalArgumentException.class, () -> FormatUtils.getTimeDuration(week, TimeUnit.DAYS))
} // Assert
} assertTrue(iae.message.contains("Value '" + week + "' is not a valid time duration"))
})
// Assert
assert msgs.every { it =~ /Value '.*' is not a valid time duration/ }
} }
/** /**
@ -78,14 +81,12 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
final List WEEKS = ["1 work", "1 wek", "1 k"] final List WEEKS = ["1 work", "1 wek", "1 k"]
// Act // Act
List msgs = WEEKS.collect { String week -> WEEKS.stream().forEach(week -> {
shouldFail(IllegalArgumentException) { IllegalArgumentException iae =
FormatUtils.getTimeDuration(week, TimeUnit.DAYS) assertThrows(IllegalArgumentException.class, () -> FormatUtils.getTimeDuration(week, TimeUnit.DAYS))
} // Assert
} assertTrue(iae.message.contains("Value '" + week + "' is not a valid time duration"))
})
// Assert
assert msgs.every { it =~ /Value '.*' is not a valid time duration/ }
} }
@ -105,7 +106,7 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(days) logger.converted(days)
// Assert // Assert
assert days.every { it == EXPECTED_DAYS } days.forEach(it -> assertEquals(EXPECTED_DAYS, it))
} }
/** /**
@ -130,8 +131,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedDecimalMillis) logger.converted(parsedDecimalMillis)
// Assert // Assert
assert parsedWholeMillis.every { it == EXPECTED_MILLIS } parsedWholeMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
assert parsedDecimalMillis.every { it == EXPECTED_MILLIS } parsedDecimalMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
} }
/** /**
@ -158,9 +159,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(oneWeekInOtherUnits) logger.converted(oneWeekInOtherUnits)
// Assert // Assert
oneWeekInOtherUnits.each { TimeUnit k, double value -> oneWeekInOtherUnits.entrySet().forEach(entry ->
assert value == ONE_WEEK_IN_OTHER_UNITS[k] assertEquals(ONE_WEEK_IN_OTHER_UNITS.get(entry.getKey()), entry.getValue()))
}
} }
/** /**
@ -187,9 +187,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(onePointFiveWeeksInOtherUnits) logger.converted(onePointFiveWeeksInOtherUnits)
// Assert // Assert
onePointFiveWeeksInOtherUnits.each { TimeUnit k, double value -> onePointFiveWeeksInOtherUnits.entrySet().forEach(entry ->
assert value == ONE_POINT_FIVE_WEEKS_IN_OTHER_UNITS[k] assertEquals(ONE_POINT_FIVE_WEEKS_IN_OTHER_UNITS.get(entry.getKey()), entry.getValue()))
}
} }
/** /**
@ -214,8 +213,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedDecimalMillis) logger.converted(parsedDecimalMillis)
// Assert // Assert
assert parsedWholeMillis.every { it == EXPECTED_MILLIS } parsedWholeMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
assert parsedDecimalMillis.every { it == EXPECTED_MILLIS } parsedDecimalMillis.forEach(it -> assertEquals(EXPECTED_MILLIS, it))
} }
/** /**
@ -240,9 +239,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.info(results) logger.info(results)
// Assert // Assert
results.every { String key, double value -> results.entrySet().forEach(entry ->
assert value == SCENARIOS[key].expectedValue assertEquals(SCENARIOS.get(entry.getKey()).expectedValue, entry.getValue()))
}
} }
/** /**
@ -265,7 +263,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedWholeNanos) logger.converted(parsedWholeNanos)
// Assert // Assert
assert parsedWholeNanos.every { it == [EXPECTED_NANOS, TimeUnit.NANOSECONDS] } parsedWholeNanos.forEach(it ->
assertEquals(Arrays.asList(EXPECTED_NANOS, TimeUnit.NANOSECONDS), it))
} }
/** /**
@ -291,8 +290,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
// Assert // Assert
results.every { String key, List values -> results.every { String key, List values ->
assert values.first() == SCENARIOS[key].expectedValue assertEquals(SCENARIOS[key].expectedValue, values.first())
assert values.last() == SCENARIOS[key].expectedUnits assertEquals(SCENARIOS[key].expectedUnits, values.last())
} }
} }
@ -316,10 +315,10 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.info(results) logger.info(results)
// Assert // Assert
results.every { String key, List values -> results.entrySet().forEach(entry -> {
assert values.first() == SCENARIOS[key].expectedValue assertEquals(SCENARIOS.get(entry.getKey()).expectedValue, entry.getValue().get(0))
assert values.last() == SCENARIOS[key].expectedUnits assertEquals(SCENARIOS.get(entry.getKey()).expectedUnits, entry.getValue().get(1))
} })
} }
/** /**
@ -339,17 +338,18 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
] ]
// Act // Act
List parsedWholeTimes = WHOLE_TIMES.collect { List it -> List<List<Object>> parsedWholeTimes = WHOLE_TIMES.collect { List it ->
FormatUtils.makeWholeNumberTime(it[0] as float, it[1] as TimeUnit) FormatUtils.makeWholeNumberTime(it[0] as float, it[1] as TimeUnit)
} }
logger.converted(parsedWholeTimes) logger.converted(parsedWholeTimes)
// Assert // Assert
parsedWholeTimes.eachWithIndex { List elements, int i -> IntStream.range(0, parsedWholeTimes.size())
assert elements[0] instanceof Long .forEach(index -> {
assert elements[0] == 10L List<Object> elements = parsedWholeTimes.get(index)
assert elements[1] == WHOLE_TIMES[i][1] assertEquals(10L, elements.get(0))
} assertEquals(WHOLE_TIMES[index][1], elements.get(1))
})
} }
/** /**
@ -379,7 +379,7 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(parsedWholeTimes) logger.converted(parsedWholeTimes)
// Assert // Assert
assert parsedWholeTimes == EXPECTED_TIMES assertEquals(EXPECTED_TIMES, parsedWholeTimes)
} }
/** /**
@ -391,15 +391,10 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
final List UNITS = TimeUnit.values() as List final List UNITS = TimeUnit.values() as List
// Act // Act
def nullMsg = shouldFail(IllegalArgumentException) { IllegalArgumentException nullIae = assertThrows(IllegalArgumentException.class,
FormatUtils.getSmallerTimeUnit(null) () -> FormatUtils.getSmallerTimeUnit(null))
} IllegalArgumentException nanoIae = assertThrows(IllegalArgumentException.class,
logger.expected(nullMsg) () -> FormatUtils.getSmallerTimeUnit(TimeUnit.NANOSECONDS))
def nanosMsg = shouldFail(IllegalArgumentException) {
FormatUtils.getSmallerTimeUnit(TimeUnit.NANOSECONDS)
}
logger.expected(nanosMsg)
List smallerTimeUnits = UNITS[1..-1].collect { TimeUnit unit -> List smallerTimeUnits = UNITS[1..-1].collect { TimeUnit unit ->
FormatUtils.getSmallerTimeUnit(unit) FormatUtils.getSmallerTimeUnit(unit)
@ -407,9 +402,9 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(smallerTimeUnits) logger.converted(smallerTimeUnits)
// Assert // Assert
assert nullMsg == "Cannot determine a smaller time unit than 'null'" assertEquals("Cannot determine a smaller time unit than 'null'", nullIae.getMessage())
assert nanosMsg == "Cannot determine a smaller time unit than 'NANOSECONDS'" assertEquals("Cannot determine a smaller time unit than 'NANOSECONDS'", nanoIae.getMessage())
assert smallerTimeUnits == UNITS[0..<-1] assertEquals(smallerTimeUnits, UNITS.subList(0, UNITS.size() - 1))
} }
/** /**
@ -435,9 +430,8 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
logger.converted(results) logger.converted(results)
// Assert // Assert
results.every { String key, long value -> results.entrySet().forEach(entry ->
assert value == SCENARIOS[key].expectedMultiplier assertEquals(SCENARIOS.get(entry.getKey()).expectedMultiplier, entry.getValue()))
}
} }
/** /**
@ -446,26 +440,21 @@ class TestFormatUtilsGroovy extends GroovyTestCase {
@Test @Test
void testCalculateMultiplierShouldHandleIncorrectUnits() { void testCalculateMultiplierShouldHandleIncorrectUnits() {
// Arrange // Arrange
final Map SCENARIOS = [ final Map<String, Map<String, TimeUnit>> SCENARIOS = [
"allUnits" : [original: TimeUnit.NANOSECONDS, destination: TimeUnit.DAYS], "allUnits" : [original: TimeUnit.NANOSECONDS, destination: TimeUnit.DAYS],
"nanosToMicros": [original: TimeUnit.NANOSECONDS, destination: TimeUnit.MICROSECONDS], "nanosToMicros": [original: TimeUnit.NANOSECONDS, destination: TimeUnit.MICROSECONDS],
"hoursToDays" : [original: TimeUnit.HOURS, destination: TimeUnit.DAYS], "hoursToDays" : [original: TimeUnit.HOURS, destination: TimeUnit.DAYS],
] ]
// Act // Act
Map results = SCENARIOS.collectEntries { String k, Map values -> SCENARIOS.entrySet().stream()
logger.debug("Evaluating ${k}: ${values}") .forEach(entry -> {
def msg = shouldFail(IllegalArgumentException) { // Assert
FormatUtils.calculateMultiplier(values.original, values.destination) IllegalArgumentException iae = assertThrows(IllegalArgumentException.class,
} () -> FormatUtils.calculateMultiplier(entry.getValue().get("original"),
logger.expected(msg) entry.getValue().get("destination")))
[k, msg] assertTrue((iae.getMessage() =~ "The original time unit '.*' must be larger than the new time unit '.*'").find())
} })
// Assert
results.every { String key, String value ->
assert value =~ "The original time unit '.*' must be larger than the new time unit '.*'"
}
} }
// TODO: Microsecond parsing // TODO: Microsecond parsing

View File

@ -22,7 +22,7 @@ import org.junit.jupiter.api.Test;
import java.text.DecimalFormatSymbols; import java.text.DecimalFormatSymbols;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestFormatUtils { public class TestFormatUtils {
@ -38,7 +38,7 @@ public class TestFormatUtils {
} }
@Test @Test
public void testFormatTime() throws Exception { public void testFormatTime() {
assertEquals("00:00:00.000", FormatUtils.formatHoursMinutesSeconds(0, TimeUnit.DAYS)); assertEquals("00:00:00.000", FormatUtils.formatHoursMinutesSeconds(0, TimeUnit.DAYS));
assertEquals("01:00:00.000", FormatUtils.formatHoursMinutesSeconds(1, TimeUnit.HOURS)); assertEquals("01:00:00.000", FormatUtils.formatHoursMinutesSeconds(1, TimeUnit.HOURS));
assertEquals("02:00:00.000", FormatUtils.formatHoursMinutesSeconds(2, TimeUnit.HOURS)); assertEquals("02:00:00.000", FormatUtils.formatHoursMinutesSeconds(2, TimeUnit.HOURS));

View File

@ -22,9 +22,9 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Arrays; import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestCompressionInputOutputStreams { public class TestCompressionInputOutputStreams {
@ -32,7 +32,7 @@ public class TestCompressionInputOutputStreams {
public void testSimple() throws IOException { public void testSimple() throws IOException {
final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final ByteArrayOutputStream baos = new ByteArrayOutputStream();
final byte[] data = "Hello, World!".getBytes("UTF-8"); final byte[] data = "Hello, World!".getBytes(StandardCharsets.UTF_8);
final CompressionOutputStream cos = new CompressionOutputStream(baos); final CompressionOutputStream cos = new CompressionOutputStream(baos);
cos.write(data); cos.write(data);
@ -43,7 +43,7 @@ public class TestCompressionInputOutputStreams {
final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes)); final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes));
final byte[] decompressed = readFully(cis); final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data, decompressed)); assertArrayEquals(data, decompressed);
} }
@Test @Test
@ -54,7 +54,7 @@ public class TestCompressionInputOutputStreams {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
sb.append(str); sb.append(str);
} }
final byte[] data = sb.toString().getBytes("UTF-8"); final byte[] data = sb.toString().getBytes(StandardCharsets.UTF_8);
final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -67,13 +67,13 @@ public class TestCompressionInputOutputStreams {
final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes)); final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes));
final byte[] decompressed = readFully(cis); final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data, decompressed)); assertArrayEquals(data, decompressed);
} }
@Test @Test
public void testDataLargerThanBufferWhileFlushing() throws IOException { public void testDataLargerThanBufferWhileFlushing() throws IOException {
final String str = "The quick brown fox jumps over the lazy dog\r\n\n\n\r"; final String str = "The quick brown fox jumps over the lazy dog\r\n\n\n\r";
final byte[] data = str.getBytes("UTF-8"); final byte[] data = str.getBytes(StandardCharsets.UTF_8);
final StringBuilder sb = new StringBuilder(); final StringBuilder sb = new StringBuilder();
final byte[] data1024; final byte[] data1024;
@ -87,19 +87,19 @@ public class TestCompressionInputOutputStreams {
sb.append(str); sb.append(str);
} }
cos.close(); cos.close();
data1024 = sb.toString().getBytes("UTF-8"); data1024 = sb.toString().getBytes(StandardCharsets.UTF_8);
final byte[] compressedBytes = baos.toByteArray(); final byte[] compressedBytes = baos.toByteArray();
final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes)); final CompressionInputStream cis = new CompressionInputStream(new ByteArrayInputStream(compressedBytes));
final byte[] decompressed = readFully(cis); final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data1024, decompressed)); assertArrayEquals(data1024, decompressed);
} }
@Test @Test
public void testSendingMultipleFilesBackToBackOnSameStream() throws IOException { public void testSendingMultipleFilesBackToBackOnSameStream() throws IOException {
final String str = "The quick brown fox jumps over the lazy dog\r\n\n\n\r"; final String str = "The quick brown fox jumps over the lazy dog\r\n\n\n\r";
final byte[] data = str.getBytes("UTF-8"); final byte[] data = str.getBytes(StandardCharsets.UTF_8);
final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -122,18 +122,18 @@ public class TestCompressionInputOutputStreams {
for (int i = 0; i < 512; i++) { for (int i = 0; i < 512; i++) {
sb.append(str); sb.append(str);
} }
data512 = sb.toString().getBytes("UTF-8"); data512 = sb.toString().getBytes(StandardCharsets.UTF_8);
final byte[] compressedBytes = baos.toByteArray(); final byte[] compressedBytes = baos.toByteArray();
final ByteArrayInputStream bais = new ByteArrayInputStream(compressedBytes); final ByteArrayInputStream bais = new ByteArrayInputStream(compressedBytes);
final CompressionInputStream cis = new CompressionInputStream(bais); final CompressionInputStream cis = new CompressionInputStream(bais);
final byte[] decompressed = readFully(cis); final byte[] decompressed = readFully(cis);
assertTrue(Arrays.equals(data512, decompressed)); assertArrayEquals(data512, decompressed);
final CompressionInputStream cis2 = new CompressionInputStream(bais); final CompressionInputStream cis2 = new CompressionInputStream(bais);
final byte[] decompressed2 = readFully(cis2); final byte[] decompressed2 = readFully(cis2);
assertTrue(Arrays.equals(data512, decompressed2)); assertArrayEquals(data512, decompressed2);
} }
private byte[] readFully(final InputStream in) throws IOException { private byte[] readFully(final InputStream in) throws IOException {

View File

@ -32,7 +32,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertLinesMatch;
public class TestLineDemarcator { public class TestLineDemarcator {
@ -41,14 +41,14 @@ public class TestLineDemarcator {
final String input = "A\nB\nC\rD\r\nE\r\nF\r\rG"; final String input = "A\nB\nC\rD\r\nE\r\nF\r\rG";
final List<String> lines = getLines(input); final List<String> lines = getLines(input);
assertEquals(Arrays.asList("A\n", "B\n", "C\r", "D\r\n", "E\r\n", "F\r", "\r", "G"), lines); assertLinesMatch(Arrays.asList("A\n", "B\n", "C\r", "D\r\n", "E\r\n", "F\r", "\r", "G"), lines);
} }
@Test @Test
public void testEmptyStream() throws IOException { public void testEmptyStream() throws IOException {
final List<String> lines = getLines(""); final List<String> lines = getLines("");
assertEquals(Collections.emptyList(), lines); assertLinesMatch(Collections.emptyList(), lines);
} }
@Test @Test
@ -56,7 +56,7 @@ public class TestLineDemarcator {
final String input = "\r\r\r\n\n\n\r\n"; final String input = "\r\r\r\n\n\n\r\n";
final List<String> lines = getLines(input); final List<String> lines = getLines(input);
assertEquals(Arrays.asList("\r", "\r", "\r\n", "\n", "\n", "\r\n"), lines); assertLinesMatch(Arrays.asList("\r", "\r", "\r\n", "\n", "\n", "\r\n"), lines);
} }
@Test @Test
@ -64,25 +64,25 @@ public class TestLineDemarcator {
final String input = "ABC\r\nXYZ"; final String input = "ABC\r\nXYZ";
final List<String> lines = getLines(input, 10, 4); final List<String> lines = getLines(input, 10, 4);
assertEquals(Arrays.asList("ABC\r\n", "XYZ"), lines); assertLinesMatch(Arrays.asList("ABC\r\n", "XYZ"), lines);
} }
@Test @Test
public void testEndsWithCarriageReturn() throws IOException { public void testEndsWithCarriageReturn() throws IOException {
final List<String> lines = getLines("ABC\r"); final List<String> lines = getLines("ABC\r");
assertEquals(Arrays.asList("ABC\r"), lines); assertLinesMatch(Arrays.asList("ABC\r"), lines);
} }
@Test @Test
public void testEndsWithNewLine() throws IOException { public void testEndsWithNewLine() throws IOException {
final List<String> lines = getLines("ABC\n"); final List<String> lines = getLines("ABC\n");
assertEquals(Arrays.asList("ABC\n"), lines); assertLinesMatch(Arrays.asList("ABC\n"), lines);
} }
@Test @Test
public void testEndsWithCarriageReturnNewLine() throws IOException { public void testEndsWithCarriageReturnNewLine() throws IOException {
final List<String> lines = getLines("ABC\r\n"); final List<String> lines = getLines("ABC\r\n");
assertEquals(Arrays.asList("ABC\r\n"), lines); assertLinesMatch(Arrays.asList("ABC\r\n"), lines);
} }
@Test @Test
@ -90,13 +90,13 @@ public class TestLineDemarcator {
final String input = "he\ra-to-a\rb-to-b\rc-to-c\r\nd-to-d"; final String input = "he\ra-to-a\rb-to-b\rc-to-c\r\nd-to-d";
final List<String> lines = getLines(input, 10, 10); final List<String> lines = getLines(input, 10, 10);
assertEquals(Arrays.asList("he\r", "a-to-a\r", "b-to-b\r", "c-to-c\r\n", "d-to-d"), lines); assertLinesMatch(Arrays.asList("he\r", "a-to-a\r", "b-to-b\r", "c-to-c\r\n", "d-to-d"), lines);
} }
@Test @Test
public void testFirstCharMatchOnly() throws IOException { public void testFirstCharMatchOnly() throws IOException {
final List<String> lines = getLines("\nThe quick brown fox jumped over the lazy dog."); final List<String> lines = getLines("\nThe quick brown fox jumped over the lazy dog.");
assertEquals(Arrays.asList("\n", "The quick brown fox jumped over the lazy dog."), lines); assertLinesMatch(Arrays.asList("\n", "The quick brown fox jumped over the lazy dog."), lines);
} }
@Test @Test

View File

@ -24,11 +24,11 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;

View File

@ -18,7 +18,7 @@ package org.apache.nifi.util;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class StringSelectorTest { public class StringSelectorTest {
@Test @Test

View File

@ -20,8 +20,8 @@ import org.junit.jupiter.api.Test;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
public class TestTimedBuffer { public class TestTimedBuffer {
@ -90,7 +90,7 @@ public class TestTimedBuffer {
return oldValue; return oldValue;
} }
return new TimestampedLong(oldValue.getValue().longValue() + toAdd.getValue().longValue()); return new TimestampedLong(oldValue.getValue() + toAdd.getValue());
} }
@Override @Override

View File

@ -22,6 +22,8 @@ import javax.servlet.FilterConfig
import javax.servlet.ServletContext import javax.servlet.ServletContext
import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletRequest
import static org.junit.jupiter.api.Assertions.assertEquals
class SanitizeContextPathFilterTest { class SanitizeContextPathFilterTest {
private static String getValue(String parameterName, Map<String, String> params = [:]) { private static String getValue(String parameterName, Map<String, String> params = [:]) {
@ -45,7 +47,7 @@ class SanitizeContextPathFilterTest {
scpf.init(mockFilterConfig) scpf.init(mockFilterConfig)
// Assert // Assert
assert scpf.getAllowedContextPaths() == EXPECTED_ALLOWED_CONTEXT_PATHS assertEquals(EXPECTED_ALLOWED_CONTEXT_PATHS, scpf.getAllowedContextPaths())
} }
@Test @Test
@ -64,7 +66,7 @@ class SanitizeContextPathFilterTest {
scpf.init(mockFilterConfig) scpf.init(mockFilterConfig)
// Assert // Assert
assert scpf.getAllowedContextPaths() == EXPECTED_ALLOWED_CONTEXT_PATHS assertEquals(EXPECTED_ALLOWED_CONTEXT_PATHS, scpf.getAllowedContextPaths())
} }
@Test @Test
@ -113,6 +115,6 @@ class SanitizeContextPathFilterTest {
scpf.injectContextPathAttribute(mockRequest) scpf.injectContextPathAttribute(mockRequest)
// Assert // Assert
assert requestAttributes["contextPath"] == EXPECTED_CONTEXT_PATH assertEquals(EXPECTED_CONTEXT_PATH, requestAttributes["contextPath"])
} }
} }