diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java index 28f2a3b5e3..3df3c89388 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java @@ -952,7 +952,7 @@ public class PutDatabaseRecord extends AbstractProcessor { } } - private String generateTableName(final DMLSettings settings, final String catalog, final String schemaName, final String tableName, final TableSchema tableSchema) { + String generateTableName(final DMLSettings settings, final String catalog, final String schemaName, final String tableName, final TableSchema tableSchema) { final StringBuilder tableNameBuilder = new StringBuilder(); if (catalog != null) { if (settings.quoteTableName) { @@ -1399,7 +1399,7 @@ public class PutDatabaseRecord extends AbstractProcessor { private Map columns; private String quotedIdentifierString; - private TableSchema(final List columnDescriptions, final boolean translateColumnNames, + TableSchema(final List columnDescriptions, final boolean translateColumnNames, final Set primaryKeyColumnNames, final String quotedIdentifierString) { this.columns = new LinkedHashMap<>(); this.primaryKeyColumnNames = primaryKeyColumnNames; @@ -1688,7 +1688,7 @@ public class PutDatabaseRecord extends AbstractProcessor { // Quote table name? private final boolean quoteTableName; - private DMLSettings(ProcessContext context) { + DMLSettings(ProcessContext context) { translateFieldNames = context.getProperty(TRANSLATE_FIELD_NAMES).asBoolean(); ignoreUnmappedFields = IGNORE_UNMATCHED_FIELD.getValue().equalsIgnoreCase(context.getProperty(UNMATCHED_FIELD_BEHAVIOR).getValue()); diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CountTextTest.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CountTextTest.groovy deleted file mode 100644 index e9fa0558e5..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CountTextTest.groovy +++ /dev/null @@ -1,396 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License") you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - -import org.apache.nifi.components.PropertyDescriptor -import org.apache.nifi.flowfile.FlowFile -import org.apache.nifi.util.MockComponentLog -import org.apache.nifi.util.MockProcessSession -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.bouncycastle.jce.provider.BouncyCastleProvider -import org.junit.After -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.mockito.Mockito -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import java.nio.charset.StandardCharsets -import java.security.Security - -import static org.mockito.ArgumentMatchers.anyBoolean -import static org.mockito.ArgumentMatchers.anyString -import static org.mockito.Mockito.doReturn -import static org.mockito.Mockito.spy -import static org.mockito.Mockito.when - -@RunWith(JUnit4.class) -class CountTextTest extends GroovyTestCase { - private static final Logger logger = LoggerFactory.getLogger(CountTextTest.class) - - private static final String TLC = "text.line.count" - private static final String TLNEC = "text.line.nonempty.count" - private static final String TWC = "text.word.count" - private static final String TCC = "text.character.count" - - - @BeforeClass - static void setUpOnce() throws Exception { - Security.addProvider(new BouncyCastleProvider()) - - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - @Before - void setUp() throws Exception { - } - - @After - void tearDown() throws Exception { - } - - @Test - void testShouldCountAllMetrics() throws Exception { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(CountText.class) - - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true") - - // This text is the same as in src/test/resources/TestCountText/jabberwocky.txt but is copied here - // to ensure that reading from a file vs. static text doesn't cause line break issues - String INPUT_TEXT = """’Twas brillig, and the slithy toves -Did gyre and gimble in the wade; -All mimsy were the borogoves, -And the mome raths outgrabe. - -"Beware the Jabberwock, my son! -The jaws that bite, the claws that catch! -Beware the Jubjub bird, and shun -The frumious Bandersnatch!" - -He took his vorpal sword in hand: -Long time the manxome foe he sought— -So rested he by the Tumtum tree, -And stood awhile in thought. - -And as in uffish thought he stood, -The Jabberwock, with eyes of flame, -Came whiffling through the tulgey wood. -And burbled as it came! - -One, two! One, two! And through and through -The vorpal blade went snicker-snack! -He left it dead, and with its head -He went galumphing back. - -"And hast thou slain the Jabberwock? -Come to my arms, my beamish boy! -O frabjous day! Callooh! Callay!" -He chortled in his joy. - -’Twas brillig, and the slithy toves -Did gyre and gimble in the wabe; -All mimsy were the borogoves, -And the mome raths outgrabe.""" - - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - assert flowFile.attributes."$TLC" == 34 as String - assert flowFile.attributes."$TLNEC" == 28 as String - assert flowFile.attributes."$TWC" == 166 as String - assert flowFile.attributes."$TCC" == 900 as String - } - - @Test - void testShouldCountEachMetric() throws Exception { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(CountText.class) - String INPUT_TEXT = new File("src/test/resources/TestCountText/jabberwocky.txt").text - - final def EXPECTED_VALUES = [ - (TLC) : 34, - (TLNEC): 28, - (TWC) : 166, - (TCC) : 900, - ] - - def linesOnly = [(CountText.TEXT_LINE_COUNT_PD): "true"] - def linesNonEmptyOnly = [(CountText.TEXT_LINE_NONEMPTY_COUNT_PD): "true"] - def wordsOnly = [(CountText.TEXT_WORD_COUNT_PD): "true"] - def charactersOnly = [(CountText.TEXT_CHARACTER_COUNT_PD): "true"] - - final List> SCENARIOS = [linesOnly, linesNonEmptyOnly, wordsOnly, charactersOnly] - - SCENARIOS.each { map -> - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false") - - // Apply the scenario-specific properties - map.each { key, value -> - runner.setProperty(key, value) - } - - runner.clearProvenanceEvents() - runner.clearTransferState() - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - EXPECTED_VALUES.each { key, value -> - if (flowFile.attributes.containsKey(key)) { - assert flowFile.attributes.get(key) == value as String - } - } - } - } - - @Test - void testShouldCountWordsSplitOnSymbol() throws Exception { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(CountText.class) - String INPUT_TEXT = new File("src/test/resources/TestCountText/jabberwocky.txt").text - - final int EXPECTED_WORD_COUNT = 167 - - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false") - runner.setProperty(CountText.SPLIT_WORDS_ON_SYMBOLS_PD, "true") - - runner.clearProvenanceEvents() - runner.clearTransferState() - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - assert flowFile.attributes.get(CountText.TEXT_WORD_COUNT) == EXPECTED_WORD_COUNT as String - } - - @Test - void testShouldCountIndependentlyPerFlowFile() throws Exception { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(CountText.class) - String INPUT_TEXT = new File("src/test/resources/TestCountText/jabberwocky.txt").text - - final def EXPECTED_VALUES = [ - (TLC) : 34, - (TLNEC): 28, - (TWC) : 166, - (TCC) : 900, - ] - - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true") - - 2.times { int i -> - runner.clearProvenanceEvents() - runner.clearTransferState() - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - EXPECTED_VALUES.each { key, value -> - if (flowFile.attributes.containsKey(key)) { - assert flowFile.attributes.get(key) == value as String - } - } - } - } - - @Test - void testShouldTrackSessionCountersAcrossMultipleFlowfiles() throws Exception { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(CountText.class) - String INPUT_TEXT = new File("src/test/resources/TestCountText/jabberwocky.txt").text - - final def EXPECTED_VALUES = [ - (TLC) : 34, - (TLNEC): 28, - (TWC) : 166, - (TCC) : 900, - ] - - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true") - - MockProcessSession mockPS = runner.processSessionFactory.createSession() - def sessionCounters = mockPS.sharedState.counterMap - logger.info("Session counters (0): ${sessionCounters}") - - int n = 2 - - n.times { int i -> - runner.clearTransferState() - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - logger.info("Session counters (${i + 1}): ${sessionCounters}") - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - EXPECTED_VALUES.each { key, value -> - if (flowFile.attributes.containsKey(key)) { - assert flowFile.attributes.get(key) == value as String - } - } - } - - assert sessionCounters.get("Lines Counted").get() == EXPECTED_VALUES[TLC] * n as long - assert sessionCounters.get("Lines (non-empty) Counted").get() == EXPECTED_VALUES[TLNEC] * n as long - assert sessionCounters.get("Words Counted").get() == EXPECTED_VALUES[TWC] * n as long - assert sessionCounters.get("Characters Counted").get() == EXPECTED_VALUES[TCC] * n as long - } - - @Test - void testShouldHandleInternalError() throws Exception { - // Arrange - CountText ct = new CountText() - ct.countLines = true - ct.countLinesNonEmpty = true - ct.countWords = true - ct.countCharacters = true - - CountText ctSpy = Mockito.spy(ct) - when(ctSpy.countWordsInLine(anyString(), anyBoolean())).thenThrow(new IOException("Expected exception")) - - final TestRunner runner = TestRunners.newTestRunner(ctSpy) - final String INPUT_TEXT = "This flowfile should throw an error" - - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true") - runner.setProperty(CountText.CHARACTER_ENCODING_PD, StandardCharsets.US_ASCII.displayName()) - - runner.enqueue(INPUT_TEXT.bytes) - - // Act - // Need initialize = true to run #onScheduled() - runner.run(1, true, true) - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_FAILURE, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_FAILURE).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - } - - @Test - void testShouldIgnoreWhitespaceWordsWhenCounting() throws Exception { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(CountText.class) - String INPUT_TEXT = "a b c" - - final int EXPECTED_WORD_COUNT = 3 - - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false") - runner.setProperty(CountText.SPLIT_WORDS_ON_SYMBOLS_PD, "true") - - runner.clearProvenanceEvents() - runner.clearTransferState() - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - assert flowFile.attributes.get(CountText.TEXT_WORD_COUNT) == EXPECTED_WORD_COUNT as String - } - - @Test - void testShouldIgnoreWhitespaceWordsWhenCountingDebugMode() throws Exception { - // Arrange - MockComponentLog componentLogger = spy(new MockComponentLog("processorId", new CountText())) - doReturn(true).when(componentLogger).isDebugEnabled() - final TestRunner runner = TestRunners.newTestRunner(CountText.class, componentLogger) - String INPUT_TEXT = "a b c" - - final int EXPECTED_WORD_COUNT = 3 - - // Reset the processor properties - runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false") - runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true") - runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false") - runner.setProperty(CountText.SPLIT_WORDS_ON_SYMBOLS_PD, "true") - - runner.clearProvenanceEvents() - runner.clearTransferState() - runner.enqueue(INPUT_TEXT.bytes) - - // Act - runner.run() - - // Assert - runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1) - FlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).first() - logger.info("Generated flowfile: ${flowFile} | ${flowFile.attributes}") - assert flowFile.attributes.get(CountText.TEXT_WORD_COUNT) == EXPECTED_WORD_COUNT as String - } - -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CryptographicHashAttributeTest.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CryptographicHashAttributeTest.groovy deleted file mode 100644 index d9f8a3fef5..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CryptographicHashAttributeTest.groovy +++ /dev/null @@ -1,416 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License") you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - - -import org.apache.nifi.security.util.crypto.HashAlgorithm -import org.apache.nifi.security.util.crypto.HashService -import org.apache.nifi.util.MockFlowFile -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.bouncycastle.jce.provider.BouncyCastleProvider -import org.junit.After -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import java.nio.charset.Charset -import java.nio.charset.StandardCharsets -import java.security.Security -import java.time.ZonedDateTime -import java.time.format.DateTimeFormatter - -@RunWith(JUnit4.class) -class CryptographicHashAttributeTest extends GroovyTestCase { - private static final Logger logger = LoggerFactory.getLogger(CryptographicHashAttributeTest.class) - - - @BeforeClass - static void setUpOnce() throws Exception { - Security.addProvider(new BouncyCastleProvider()) - - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - @Before - void setUp() throws Exception { - } - - @After - void tearDown() throws Exception { - } - - @Test - void testShouldCalculateHashOfPresentAttribute() { - // Arrange - def algorithms = HashAlgorithm.values() - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashAttribute()) - - // Create attributes for username and date - def attributes = [ - username: "alopresto", - // FIXME groovy-datetime dependency not providing the "format" method for new Date().format("YYYY-MM-dd HH:mm:ss.SSS Z") - // adding the following workaround temporarily - date : ZonedDateTime.now().format(DateTimeFormatter.ofPattern('YYYY-MM-dd HH:mm:ss.SSS Z')) - ] - def attributeKeys = attributes.keySet() - - algorithms.each { HashAlgorithm algorithm -> - final EXPECTED_USERNAME_HASH = HashService.hashValue(algorithm, attributes["username"]) - logger.expected("${algorithm.name.padLeft(11)}(${attributes["username"]}) = ${EXPECTED_USERNAME_HASH}") - final EXPECTED_DATE_HASH = HashService.hashValue(algorithm, attributes["date"]) - logger.expected("${algorithm.name.padLeft(11)}(${attributes["date"]}) = ${EXPECTED_DATE_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.name) - - // Add the desired dynamic properties - attributeKeys.each { String attr -> - runner.setProperty(attr, "${attr}_${algorithm.name}") - } - - // Insert the attributes in the mock flowfile - runner.enqueue(new byte[0], attributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashedUsername = flowFile.getAttribute("username_${algorithm.name}") - logger.info("flowfile.username_${algorithm.name} = ${hashedUsername}") - String hashedDate = flowFile.getAttribute("date_${algorithm.name}") - logger.info("flowfile.date_${algorithm.name} = ${hashedDate}") - - assert hashedUsername == EXPECTED_USERNAME_HASH - assert hashedDate == EXPECTED_DATE_HASH - } - } - - @Test - void testShouldCalculateHashOfMissingAttribute() { - // Arrange - def algorithms = HashAlgorithm.values() - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashAttribute()) - - // Create attributes for username (empty string) and date (null) - def attributes = [ - username: "", - date : null - ] - def attributeKeys = attributes.keySet() - - algorithms.each { HashAlgorithm algorithm -> - final EXPECTED_USERNAME_HASH = HashService.hashValue(algorithm, attributes["username"]) - logger.expected("${algorithm.name.padLeft(11)}(${attributes["username"]}) = ${EXPECTED_USERNAME_HASH}") - final EXPECTED_DATE_HASH = null - logger.expected("${algorithm.name.padLeft(11)}(${attributes["date"]}) = ${EXPECTED_DATE_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.name) - - // Add the desired dynamic properties - attributeKeys.each { String attr -> - runner.setProperty(attr, "${attr}_${algorithm.name}") - } - - // Insert the attributes in the mock flowfile - runner.enqueue(new byte[0], attributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashedUsername = flowFile.getAttribute("username_${algorithm.name}") - logger.info("flowfile.username_${algorithm.name} = ${hashedUsername}") - String hashedDate = flowFile.getAttribute("date_${algorithm.name}") - logger.info("flowfile.date_${algorithm.name} = ${hashedDate}") - - assert hashedUsername == EXPECTED_USERNAME_HASH - assert hashedDate == EXPECTED_DATE_HASH - } - } - - @Test - void testShouldRouteToFailureOnProhibitedMissingAttribute() { - // Arrange - def algorithms = HashAlgorithm.values() - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashAttribute()) - - // Create attributes for username (empty string) and date (null) - def attributes = [ - username: "", - date : null - ] - def attributeKeys = attributes.keySet() - - algorithms.each { HashAlgorithm algorithm -> - final EXPECTED_USERNAME_HASH = HashService.hashValue(algorithm, attributes["username"]) - logger.expected("${algorithm.name.padLeft(11)}(${attributes["username"]}) = ${EXPECTED_USERNAME_HASH}") - final EXPECTED_DATE_HASH = null - logger.expected("${algorithm.name.padLeft(11)}(${attributes["date"]}) = ${EXPECTED_DATE_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.name) - - // Set to fail if there are missing attributes - runner.setProperty(CryptographicHashAttribute.PARTIAL_ATTR_ROUTE_POLICY, CryptographicHashAttribute.PartialAttributePolicy.PROHIBIT.name()) - - // Add the desired dynamic properties - attributeKeys.each { String attr -> - runner.setProperty(attr, "${attr}_${algorithm.name}") - } - - // Insert the attributes in the mock flowfile - runner.enqueue(new byte[0], attributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 1) - runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 0) - - final List failedFlowFiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_FAILURE) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = failedFlowFiles.first() - logger.info("Failed flowfile has attributes ${flowFile.attributes}") - attributeKeys.each { String missingAttribute -> - flowFile.assertAttributeNotExists("${missingAttribute}_${algorithm.name}") - } - } - } - - @Test - void testShouldRouteToFailureOnEmptyAttributes() { - // Arrange - def algorithms = HashAlgorithm.values() - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashAttribute()) - - // Create attributes for username (empty string) and date (null) - def attributes = [ - username: "", - date : null - ] - def attributeKeys = attributes.keySet() - - algorithms.each { HashAlgorithm algorithm -> - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.name) - - // Set to fail if all attributes are missing - runner.setProperty(CryptographicHashAttribute.FAIL_WHEN_EMPTY, "true") - - // Insert the attributes in the mock flowfile - runner.enqueue(new byte[0], attributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 1) - runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 0) - - final List failedFlowFiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_FAILURE) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = failedFlowFiles.first() - logger.info("Failed flowfile has attributes ${flowFile.attributes}") - attributeKeys.each { String missingAttribute -> - flowFile.assertAttributeNotExists("${missingAttribute}_${algorithm.name}") - } - } - } - - @Test - void testShouldRouteToSuccessOnAllowPartial() { - // Arrange - def algorithms = HashAlgorithm.values() - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashAttribute()) - - // Create attributes for username (empty string) and date (null) - def attributes = [ - username: "" - ] - def attributeKeys = attributes.keySet() - - algorithms.each { HashAlgorithm algorithm -> - final EXPECTED_USERNAME_HASH = HashService.hashValue(algorithm, attributes["username"]) - logger.expected("${algorithm.name.padLeft(11)}(${attributes["username"]}) = ${EXPECTED_USERNAME_HASH}") - final EXPECTED_DATE_HASH = null - logger.expected("${algorithm.name.padLeft(11)}(${attributes["date"]}) = ${EXPECTED_DATE_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.name) - - // Set to fail if there are missing attributes - runner.setProperty(CryptographicHashAttribute.PARTIAL_ATTR_ROUTE_POLICY, CryptographicHashAttribute.PartialAttributePolicy.ALLOW.name()) - - // Add the desired dynamic properties - attributeKeys.each { String attr -> - runner.setProperty(attr, "${attr}_${algorithm.name}") - } - - // Insert the attributes in the mock flowfile - runner.enqueue(new byte[0], attributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1) - - final List successfulFlowFiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowFiles.first() - logger.info("Successful flowfile has attributes ${flowFile.attributes}") - attributeKeys.each { String attribute -> - flowFile.assertAttributeExists("${attribute}_${algorithm.name}") - } - } - } - - @Test - void testShouldCalculateHashWithVariousCharacterEncodings() { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashAttribute()) - - // Create attributes - def attributes = [test_attribute: "apachenifi"] - def attributeKeys = attributes.keySet() - - HashAlgorithm algorithm = HashAlgorithm.MD5 - - List charsets = [StandardCharsets.UTF_8, StandardCharsets.UTF_16, StandardCharsets.UTF_16LE, StandardCharsets.UTF_16BE] - - final def EXPECTED_MD5_HASHES = [ - "utf_8" : "a968b5ec1d52449963dcc517789baaaf", - "utf_16" : "b8413d18f7e64042bb0322a1cd61eba2", - "utf_16be": "b8413d18f7e64042bb0322a1cd61eba2", - "utf_16le": "91c3b67f9f8ae77156f21f271cc09121", - ] - EXPECTED_MD5_HASHES.each { k, hash -> - logger.expected("MD5(${k.padLeft(9)}(${attributes["test_attribute"]})) = ${hash}") - } - - charsets.each { Charset charset -> - // Calculate the expected hash value given the character set - final EXPECTED_HASH = HashService.hashValue(algorithm, attributes["test_attribute"], charset) - logger.expected("${algorithm.name}(${attributes["test_attribute"]}, ${charset.name()}) = ${EXPECTED_HASH}") - - // Sanity check - assert EXPECTED_HASH == EXPECTED_MD5_HASHES[translateEncodingToMapKey(charset.name())] - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the properties - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.name) - - logger.info("Setting character set to ${charset.name()}") - runner.setProperty(CryptographicHashAttribute.CHARACTER_SET, charset.name()) - - // Add the desired dynamic properties - attributeKeys.each { String attr -> - runner.setProperty(attr, "${attr}_${algorithm.name}") - } - - // Insert the attributes in the mock flowfile - runner.enqueue(new byte[0], attributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashedAttribute = flowFile.getAttribute("test_attribute_${algorithm.name}") - logger.info("flowfile.test_attribute_${algorithm.name} = ${hashedAttribute}") - - assert hashedAttribute == EXPECTED_HASH - } - } - - static String translateEncodingToMapKey(String charsetName) { - charsetName.toLowerCase().replaceAll(/[-\/]/, '_') - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CryptographicHashContentTest.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CryptographicHashContentTest.groovy deleted file mode 100644 index ec25594648..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/CryptographicHashContentTest.groovy +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License") you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - - -import org.apache.nifi.security.util.crypto.HashAlgorithm -import org.apache.nifi.security.util.crypto.HashService -import org.apache.nifi.util.MockFlowFile -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.bouncycastle.jce.provider.BouncyCastleProvider -import org.junit.After -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import java.nio.charset.StandardCharsets -import java.security.Security - -@RunWith(JUnit4.class) -class CryptographicHashContentTest extends GroovyTestCase { - private static final Logger logger = LoggerFactory.getLogger(CryptographicHashContentTest.class) - - @BeforeClass - static void setUpOnce() throws Exception { - Security.addProvider(new BouncyCastleProvider()) - - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - @Before - void setUp() throws Exception { - } - - @After - void tearDown() throws Exception { - } - - @Test - void testShouldCalculateHashOfPresentContent() { - // Arrange - def algorithms = HashAlgorithm.values() - - // Generate some long content (90 KB) - final String LONG_CONTENT = "apachenifi " * 8192 - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashContent()) - - algorithms.each { HashAlgorithm algorithm -> - final String EXPECTED_CONTENT_HASH = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(LONG_CONTENT.bytes)) - logger.info("Expected ${algorithm.name.padLeft(11)}: ${EXPECTED_CONTENT_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.name) - - // Insert the content in the mock flowfile - runner.enqueue(LONG_CONTENT.getBytes(StandardCharsets.UTF_8), - [size: LONG_CONTENT.length() as String]) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashAttribute = "content_${algorithm.name}" - flowFile.assertAttributeExists(hashAttribute) - - String hashedContent = flowFile.getAttribute(hashAttribute) - logger.info("flowfile.${hashAttribute} = ${hashedContent}") - - assert hashedContent == EXPECTED_CONTENT_HASH - } - } - - @Test - void testShouldCalculateHashOfEmptyContent() { - // Arrange - def algorithms = HashAlgorithm.values() - - final String EMPTY_CONTENT = "" - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashContent()) - - algorithms.each { HashAlgorithm algorithm -> - final String EXPECTED_CONTENT_HASH = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(EMPTY_CONTENT.bytes)) - logger.info("Expected ${algorithm.name.padLeft(11)}: ${EXPECTED_CONTENT_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.name) - - // Insert the content in the mock flowfile - runner.enqueue(EMPTY_CONTENT.getBytes(StandardCharsets.UTF_8), [size: "0"]) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashAttribute = "content_${algorithm.name}" - flowFile.assertAttributeExists(hashAttribute) - - String hashedContent = flowFile.getAttribute(hashAttribute) - logger.info("flowfile.${hashAttribute} = ${hashedContent}") - - assert hashedContent == EXPECTED_CONTENT_HASH - } - } - - /** - * This test works because {@link MockFlowFile} uses the actual internal {@code data.size} for {@code getSize ( )}, while {@code StandardFlowFileRecord} uses a separate {@code size} field. May need to use {@code flowfile.getContentClaim ( ) .getLength ( )}. - */ - @Test - void testShouldCalculateHashOfContentWithIncorrectSizeAttribute() { - // Arrange - def algorithms = HashAlgorithm.values() - - final String NON_EMPTY_CONTENT = "apachenifi" - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashContent()) - - algorithms.each { HashAlgorithm algorithm -> - final String EXPECTED_CONTENT_HASH = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(NON_EMPTY_CONTENT.bytes)) - logger.info("Expected ${algorithm.name.padLeft(11)}: ${EXPECTED_CONTENT_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.name) - - // Insert the content in the mock flowfile (with the wrong size attribute) - runner.enqueue(NON_EMPTY_CONTENT.getBytes(StandardCharsets.UTF_8), [size: "0"]) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashAttribute = "content_${algorithm.name}" - flowFile.assertAttributeExists(hashAttribute) - - String hashedContent = flowFile.getAttribute(hashAttribute) - logger.info("flowfile.${hashAttribute} = ${hashedContent}") - - assert hashedContent == EXPECTED_CONTENT_HASH - } - } - - @Test - void testShouldOverwriteExistingAttribute() { - // Arrange - final String NON_EMPTY_CONTENT = "apachenifi" - final String OLD_HASH_ATTRIBUTE_VALUE = "OLD VALUE" - - HashAlgorithm algorithm = HashAlgorithm.SHA256 - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashContent()) - - final String EXPECTED_CONTENT_HASH = HashService.hashValue(algorithm, NON_EMPTY_CONTENT) - logger.info("Expected ${algorithm.name.padLeft(11)}: ${EXPECTED_CONTENT_HASH}") - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.name) - - // Insert the content in the mock flowfile (with an existing attribute) - def oldAttributes = [("content_${algorithm.name}".toString()): OLD_HASH_ATTRIBUTE_VALUE] - runner.enqueue(NON_EMPTY_CONTENT.getBytes(StandardCharsets.UTF_8), - oldAttributes) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0) - runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1) - - final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = successfulFlowfiles.first() - String hashAttribute = "content_${algorithm.name}" - flowFile.assertAttributeExists(hashAttribute) - - String hashedContent = flowFile.getAttribute(hashAttribute) - logger.info("flowfile.${hashAttribute} = ${hashedContent}") - - assert hashedContent != OLD_HASH_ATTRIBUTE_VALUE - assert hashedContent == EXPECTED_CONTENT_HASH - } - - @Test - void testShouldRouteToFailureOnEmptyContent() { - // Arrange - def algorithms = HashAlgorithm.values() - - final String EMPTY_CONTENT = "" - - final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashContent()) - - algorithms.each { HashAlgorithm algorithm -> - final String EXPECTED_CONTENT_HASH = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(EMPTY_CONTENT.bytes)) - logger.info("Expected ${algorithm.name.padLeft(11)}: ${EXPECTED_CONTENT_HASH}") - - // Reset the processor - runner.clearProperties() - runner.clearProvenanceEvents() - runner.clearTransferState() - - // Set the failure property - logger.info("Setting fail when empty to true") - runner.setProperty(CryptographicHashContent.FAIL_WHEN_EMPTY, "true") - - // Set the algorithm - logger.info("Setting hash algorithm to ${algorithm.name}") - runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.name) - - // Insert the content in the mock flowfile - runner.enqueue(EMPTY_CONTENT.getBytes(StandardCharsets.UTF_8)) - - // Act - runner.run(1) - - // Assert - runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 1) - runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 0) - - final List failedFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_FAILURE) - - // Extract the generated attributes from the flowfile - MockFlowFile flowFile = failedFlowfiles.first() - String hashAttribute = "content_${algorithm.name}" - flowFile.assertAttributeNotExists(hashAttribute) - } - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/ParseSyslogGroovyTest.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/ParseSyslogGroovyTest.groovy deleted file mode 100644 index e34902ba13..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/ParseSyslogGroovyTest.groovy +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - -import org.apache.nifi.syslog.parsers.SyslogParser -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.bouncycastle.util.encoders.Hex -import org.junit.After -import org.junit.Assert -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -@RunWith(JUnit4.class) -class ParseSyslogGroovyTest extends GroovyTestCase { - private static final Logger logger = LoggerFactory.getLogger(ParseSyslogGroovyTest.class) - - @BeforeClass - static void setUpOnce() throws Exception { - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - @Before - void setUp() throws Exception { - } - - @After - void tearDown() throws Exception { - } - - @Test - void testShouldHandleZeroLengthUDP() throws Exception { - // Arrange - final ParseSyslog proc = new ParseSyslog() - final TestRunner runner = TestRunners.newTestRunner(proc) - runner.setProperty(ParseSyslog.CHARSET, ParseSyslog.CHARSET.defaultValue) - - // Inject a SyslogParser which will always return null - def nullEventParser = [parseEvent: { byte[] bytes, String sender -> - logger.mock("Regardless of input bytes: [${Hex.toHexString(bytes)}] and sender: [${sender}], this parser will return null") - return null - }] as SyslogParser - proc.parser = nullEventParser - - final int numMessages = 10 - - // Act - numMessages.times { - runner.enqueue("Doesn't matter what is enqueued here") - } - runner.run(numMessages) - - int numFailed = runner.getFlowFilesForRelationship(ParseSyslog.REL_FAILURE).size() - int numSuccess = runner.getFlowFilesForRelationship(ParseSyslog.REL_SUCCESS).size() - logger.info("Transferred " + numSuccess + " to SUCCESS and " + numFailed + " to FAILURE") - - // Assert - - // all messages should be transferred to invalid - Assert.assertEquals("Did not process all the messages", numMessages, numFailed) - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/SplitXmlTest.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/SplitXmlTest.groovy deleted file mode 100644 index f04dca685e..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/SplitXmlTest.groovy +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.junit.After -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import java.nio.file.Paths - - -@RunWith(JUnit4.class) -class SplitXmlTest extends GroovyTestCase { - private static final Logger logger = LoggerFactory.getLogger(SplitXmlTest.class) - - @BeforeClass - static void setUpOnce() throws Exception { - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - @Before - void setUp() throws Exception { - - } - - @After - void tearDown() throws Exception { - - } - - @Test - void testShouldHandleXXEInTemplate() { - // Arrange - final String XXE_TEMPLATE_FILEPATH = "src/test/resources/xxe_template.xml" - final TestRunner runner = TestRunners.newTestRunner(new SplitXml()) - runner.setProperty(SplitXml.SPLIT_DEPTH, "3") - runner.enqueue(Paths.get(XXE_TEMPLATE_FILEPATH)) - - // Act - runner.run() - logger.info("SplitXML processor ran") - - // Assert - runner.assertAllFlowFilesTransferred(SplitXml.REL_FAILURE) - } - - @Test - void testShouldHandleRemoteCallXXE() { - // Arrange - final String XXE_TEMPLATE_FILEPATH = "src/test/resources/xxe_from_report.xml" - final TestRunner runner = TestRunners.newTestRunner(new SplitXml()) - runner.setProperty(SplitXml.SPLIT_DEPTH, "3") - runner.enqueue(Paths.get(XXE_TEMPLATE_FILEPATH)) - - // Act - runner.run() - logger.info("SplitXML processor ran") - - // Assert - runner.assertAllFlowFilesTransferred(SplitXml.REL_FAILURE) - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestCalculateRecordStats.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestCalculateRecordStats.groovy deleted file mode 100644 index 9f6b5a0644..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestCalculateRecordStats.groovy +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.nifi.processors.standard - -import org.apache.nifi.serialization.SimpleRecordSchema -import org.apache.nifi.serialization.record.MapRecord -import org.apache.nifi.serialization.record.MockRecordParser -import org.apache.nifi.serialization.record.RecordField -import org.apache.nifi.serialization.record.RecordFieldType -import org.apache.nifi.serialization.record.RecordSchema -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.junit.Assert -import org.junit.Before -import org.junit.Test - -class TestCalculateRecordStats { - TestRunner runner - MockRecordParser recordParser - RecordSchema personSchema - - @Before - void setup() { - runner = TestRunners.newTestRunner(CalculateRecordStats.class) - recordParser = new MockRecordParser() - runner.addControllerService("recordReader", recordParser) - runner.setProperty(CalculateRecordStats.RECORD_READER, "recordReader") - runner.enableControllerService(recordParser) - runner.assertValid() - - recordParser.addSchemaField("id", RecordFieldType.INT) - List personFields = new ArrayList<>() - RecordField nameField = new RecordField("name", RecordFieldType.STRING.getDataType()) - RecordField ageField = new RecordField("age", RecordFieldType.INT.getDataType()) - RecordField sportField = new RecordField("sport", RecordFieldType.STRING.getDataType()) - personFields.add(nameField) - personFields.add(ageField) - personFields.add(sportField) - personSchema = new SimpleRecordSchema(personFields) - recordParser.addSchemaField("person", RecordFieldType.RECORD) - } - - @Test - void testNoNullOrEmptyRecordFields() { - def sports = [ "Soccer", "Soccer", "Soccer", "Football", "Football", "Basketball" ] - def expectedAttributes = [ - "recordStats.sport.Soccer": "3", - "recordStats.sport.Football": "2", - "recordStats.sport.Basketball": "1", - "recordStats.sport": "6", - "record.count": "6" - ] - - commonTest([ "sport": "/person/sport"], sports, expectedAttributes) - } - - @Test - void testWithNullFields() { - def sports = [ "Soccer", null, null, "Football", null, "Basketball" ] - def expectedAttributes = [ - "recordStats.sport.Soccer": "1", - "recordStats.sport.Football": "1", - "recordStats.sport.Basketball": "1", - "recordStats.sport": "3", - "record.count": "6" - ] - - commonTest([ "sport": "/person/sport"], sports, expectedAttributes) - } - - @Test - void testWithFilters() { - def sports = [ "Soccer", "Soccer", "Soccer", "Football", "Football", "Basketball" ] - def expectedAttributes = [ - "recordStats.sport.Soccer": "3", - "recordStats.sport.Basketball": "1", - "recordStats.sport": "4", - "record.count": "6" - ] - - def propz = [ - "sport": "/person/sport[. != 'Football']" - ] - - commonTest(propz, sports, expectedAttributes) - } - - @Test - void testWithSizeLimit() { - runner.setProperty(CalculateRecordStats.LIMIT, "3") - def sports = [ "Soccer", "Soccer", "Soccer", "Football", "Football", - "Basketball", "Baseball", "Baseball", "Baseball", "Baseball", - "Skiing", "Skiing", "Skiing", "Snowboarding" - ] - def expectedAttributes = [ - "recordStats.sport.Skiing": "3", - "recordStats.sport.Soccer": "3", - "recordStats.sport.Baseball": "4", - "recordStats.sport": String.valueOf(sports.size()), - "record.count": String.valueOf(sports.size()) - ] - - def propz = [ - "sport": "/person/sport" - ] - - commonTest(propz, sports, expectedAttributes) - } - - private void commonTest(Map procProperties, List sports, Map expectedAttributes) { - int index = 1 - sports.each { sport -> - recordParser.addRecord(index++, new MapRecord(personSchema, [ - "name" : "John Doe", - "age" : 48, - "sport": sport - ])) - } - - procProperties.each { kv -> - runner.setProperty(kv.key, kv.value) - } - - runner.enqueue("") - runner.run() - runner.assertTransferCount(CalculateRecordStats.REL_FAILURE, 0) - runner.assertTransferCount(CalculateRecordStats.REL_SUCCESS, 1) - - def flowFiles = runner.getFlowFilesForRelationship(CalculateRecordStats.REL_SUCCESS) - def ff = flowFiles[0] - expectedAttributes.each { kv -> - Assert.assertNotNull("Missing ${kv.key}", ff.getAttribute(kv.key)) - Assert.assertEquals(kv.value, ff.getAttribute(kv.key)) - } - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestEncryptContentGroovy.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestEncryptContentGroovy.groovy deleted file mode 100644 index 0e711e9f0d..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestEncryptContentGroovy.groovy +++ /dev/null @@ -1,943 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License") you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - -import groovy.time.TimeCategory -import groovy.time.TimeDuration -import org.apache.commons.codec.binary.Hex -import org.apache.nifi.components.ValidationResult -import org.apache.nifi.security.util.EncryptionMethod -import org.apache.nifi.security.util.KeyDerivationFunction -import org.apache.nifi.security.util.crypto.Argon2CipherProvider -import org.apache.nifi.security.util.crypto.Argon2SecureHasher -import org.apache.nifi.security.util.crypto.CipherUtility -import org.apache.nifi.security.util.crypto.KeyedEncryptor -import org.apache.nifi.security.util.crypto.PasswordBasedEncryptor -import org.apache.nifi.security.util.crypto.RandomIVPBECipherProvider -import org.apache.nifi.util.MockFlowFile -import org.apache.nifi.util.MockProcessContext -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.bouncycastle.jce.provider.BouncyCastleProvider -import org.junit.After -import org.junit.Assert -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import javax.crypto.Cipher -import java.nio.charset.StandardCharsets -import java.nio.file.Paths -import java.security.Security -import java.text.SimpleDateFormat -import java.time.Instant -import java.time.temporal.ChronoUnit - -@RunWith(JUnit4.class) -class TestEncryptContentGroovy { - private static final Logger logger = LoggerFactory.getLogger(TestEncryptContentGroovy.class) - - private static final String WEAK_CRYPTO_ALLOWED = EncryptContent.WEAK_CRYPTO_ALLOWED_NAME - private static final String WEAK_CRYPTO_NOT_ALLOWED = EncryptContent.WEAK_CRYPTO_NOT_ALLOWED_NAME - - private static final List SUPPORTED_KEYED_ENCRYPTION_METHODS = EncryptionMethod.values().findAll { it.isKeyedCipher() && it != EncryptionMethod.AES_CBC_NO_PADDING } - - @BeforeClass - static void setUpOnce() throws Exception { - Security.addProvider(new BouncyCastleProvider()) - - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - @Before - void setUp() throws Exception { - } - - @After - void tearDown() throws Exception { - } - - @Test - void testShouldValidateMaxKeySizeForAlgorithmsOnUnlimitedStrengthJVM() throws IOException { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - - EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC - - // Integer.MAX_VALUE or 128, so use 256 or 128 - final int MAX_KEY_LENGTH = [PasswordBasedEncryptor.getMaxAllowedKeyLength(encryptionMethod.algorithm), 256].min() - final String TOO_LONG_KEY_HEX = "ab" * (MAX_KEY_LENGTH / 8 + 1) - logger.info("Using key ${TOO_LONG_KEY_HEX} (${TOO_LONG_KEY_HEX.length() * 4} bits)") - - runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()) - runner.setProperty(EncryptContent.RAW_KEY_HEX, TOO_LONG_KEY_HEX) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - Assert.assertEquals(1, results.size()) - logger.expected(results) - ValidationResult vr = results.first() - - String expectedResult = "'raw-key-hex' is invalid because Key must be valid length [128, 192, 256]" - String message = "'" + vr.toString() + "' contains '" + expectedResult + "'" - Assert.assertTrue(message, vr.toString().contains(expectedResult)) - } - - @Test - void testShouldValidateKeyFormatAndSizeForAlgorithms() throws IOException { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - - EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC - - final int INVALID_KEY_LENGTH = 120 - final String INVALID_KEY_HEX = "ab" * (INVALID_KEY_LENGTH / 8) - logger.info("Using key ${INVALID_KEY_HEX} (${INVALID_KEY_HEX.length() * 4} bits)") - - runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()) - runner.setProperty(EncryptContent.RAW_KEY_HEX, INVALID_KEY_HEX) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - Assert.assertEquals(1, results.size()) - logger.expected(results) - ValidationResult keyLengthInvalidVR = results.first() - - String expectedResult = "'raw-key-hex' is invalid because Key must be valid length [128, 192, 256]" - String message = "'" + keyLengthInvalidVR.toString() + "' contains '" + expectedResult + "'" - Assert.assertTrue(message, keyLengthInvalidVR.toString().contains(expectedResult)) - } - - @Test - void testShouldValidateKDFWhenKeyedCipherSelected() { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - - final int VALID_KEY_LENGTH = 128 - final String VALID_KEY_HEX = "ab" * (VALID_KEY_LENGTH / 8) - logger.info("Using key ${VALID_KEY_HEX} (${VALID_KEY_HEX.length() * 4} bits)") - - runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - SUPPORTED_KEYED_ENCRYPTION_METHODS.each { EncryptionMethod encryptionMethod -> - logger.info("Trying encryption method ${encryptionMethod.name()}") - runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - - // Scenario 1: Legacy KDF + keyed cipher -> validation error - final def INVALID_KDFS = [KeyDerivationFunction.NIFI_LEGACY, KeyDerivationFunction.OPENSSL_EVP_BYTES_TO_KEY] - INVALID_KDFS.each { KeyDerivationFunction invalidKDF -> - logger.info("Trying KDF ${invalidKDF.name()}") - - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, invalidKDF.name()) - runner.setProperty(EncryptContent.RAW_KEY_HEX, VALID_KEY_HEX) - runner.removeProperty(EncryptContent.PASSWORD) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - logger.expected(results) - assert results.size() == 1 - ValidationResult keyLengthInvalidVR = results.first() - - String expectedResult = "'key-derivation-function' is invalid because Key Derivation Function is required to be BCRYPT, SCRYPT, PBKDF2, ARGON2, NONE when using " + - "algorithm ${encryptionMethod.algorithm}" - String message = "'" + keyLengthInvalidVR.toString() + "' contains '" + expectedResult + "'" - assert keyLengthInvalidVR.toString().contains(expectedResult) - } - - // Scenario 2: No KDF + keyed cipher + raw-key-hex -> valid - def none = KeyDerivationFunction.NONE - logger.info("Trying KDF ${none.name()}") - - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, none.name()) - runner.setProperty(EncryptContent.RAW_KEY_HEX, VALID_KEY_HEX) - runner.removeProperty(EncryptContent.PASSWORD) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - assert results.isEmpty() - - // Scenario 3: Strong KDF + keyed cipher + password -> valid - final def VALID_KDFS = [KeyDerivationFunction.BCRYPT, KeyDerivationFunction.SCRYPT, KeyDerivationFunction.PBKDF2, KeyDerivationFunction.ARGON2] - VALID_KDFS.each { KeyDerivationFunction validKDF -> - logger.info("Trying KDF ${validKDF.name()}") - - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, validKDF.name()) - runner.setProperty(EncryptContent.PASSWORD, "thisIsABadPassword") - runner.removeProperty(EncryptContent.RAW_KEY_HEX) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - assert results.isEmpty() - } - } - } - - @Test - void testKDFShouldDefaultToNone() { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - String defaultKDF = pc.getProperty("key-derivation-function").getValue() - - // Assert - assert defaultKDF == KeyDerivationFunction.NONE.name() - } - - @Test - void testEMShouldDefaultToAES_GCM() { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - String defaultEM = pc.getProperty("Encryption Algorithm").getValue() - - // Assert - assert defaultEM == EncryptionMethod.AES_GCM.name() - } - - @Test - void testShouldValidateKeyMaterialSourceWhenKeyedCipherSelected() { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - - logger.info("Testing keyed encryption methods: ${SUPPORTED_KEYED_ENCRYPTION_METHODS*.name()}") - - final int VALID_KEY_LENGTH = 128 - final String VALID_KEY_HEX = "ab" * (VALID_KEY_LENGTH / 8) - logger.info("Using key ${VALID_KEY_HEX} (${VALID_KEY_HEX.length() * 4} bits)") - - final String VALID_PASSWORD = "thisIsABadPassword" - logger.info("Using password ${VALID_PASSWORD} (${VALID_PASSWORD.length()} bytes)") - - runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - KeyDerivationFunction none = KeyDerivationFunction.NONE - final def VALID_KDFS = KeyDerivationFunction.values().findAll { it.isStrongKDF() } - - // Scenario 1 - RKH w/ KDF NONE & em in [CBC, CTR, GCM] (no password) - SUPPORTED_KEYED_ENCRYPTION_METHODS.each { EncryptionMethod kem -> - logger.info("Trying encryption method ${kem.name()} with KDF ${none.name()}") - runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, kem.name()) - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, none.name()) - - logger.info("Setting raw key hex: ${VALID_KEY_HEX}") - runner.setProperty(EncryptContent.RAW_KEY_HEX, VALID_KEY_HEX) - runner.removeProperty(EncryptContent.PASSWORD) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - assert results.isEmpty() - - // Scenario 2 - PW w/ KDF in [BCRYPT, SCRYPT, PBKDF2, ARGON2] & em in [CBC, CTR, GCM] (no RKH) - VALID_KDFS.each { KeyDerivationFunction kdf -> - logger.info("Trying encryption method ${kem.name()} with KDF ${kdf.name()}") - runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, kem.name()) - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()) - - logger.info("Setting password: ${VALID_PASSWORD}") - runner.removeProperty(EncryptContent.RAW_KEY_HEX) - runner.setProperty(EncryptContent.PASSWORD, VALID_PASSWORD) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - assert results.isEmpty() - } - } - } - - @Test - void testShouldValidateKDFWhenPBECipherSelected() { - // Arrange - final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class) - Collection results - MockProcessContext pc - final String PASSWORD = "short" - - def encryptionMethods = EncryptionMethod.values().findAll { it.algorithm.startsWith("PBE") } - - runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - runner.setProperty(EncryptContent.PASSWORD, PASSWORD) - runner.setProperty(EncryptContent.ALLOW_WEAK_CRYPTO, WEAK_CRYPTO_ALLOWED) - - encryptionMethods.each { EncryptionMethod encryptionMethod -> - logger.info("Trying encryption method ${encryptionMethod.name()}") - runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - - final def INVALID_KDFS = [KeyDerivationFunction.NONE, KeyDerivationFunction.BCRYPT, KeyDerivationFunction.SCRYPT, KeyDerivationFunction.PBKDF2, KeyDerivationFunction.ARGON2] - INVALID_KDFS.each { KeyDerivationFunction invalidKDF -> - logger.info("Trying KDF ${invalidKDF.name()}") - - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, invalidKDF.name()) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - logger.expected(results) - Assert.assertEquals(1, results.size()) - ValidationResult keyLengthInvalidVR = results.first() - - String expectedResult = "'Key Derivation Function' is invalid because Key Derivation Function is required to be NIFI_LEGACY, OPENSSL_EVP_BYTES_TO_KEY when using " + - "algorithm ${encryptionMethod.algorithm}" - String message = "'" + keyLengthInvalidVR.toString() + "' contains '" + expectedResult + "'" - Assert.assertTrue(message, keyLengthInvalidVR.toString().contains(expectedResult)) - } - - final def VALID_KDFS = [KeyDerivationFunction.NIFI_LEGACY, KeyDerivationFunction.OPENSSL_EVP_BYTES_TO_KEY] - VALID_KDFS.each { KeyDerivationFunction validKDF -> - logger.info("Trying KDF ${validKDF.name()}") - - runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, validKDF.name()) - - runner.enqueue(new byte[0]) - pc = (MockProcessContext) runner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - Assert.assertEquals(0, results.size()) - } - } - } - - @Test - void testRoundTrip() throws IOException { - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - final String RAW_KEY_HEX = "ab" * 16 - testRunner.setProperty(EncryptContent.RAW_KEY_HEX, RAW_KEY_HEX) - testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()) - - SUPPORTED_KEYED_ENCRYPTION_METHODS.each { EncryptionMethod encryptionMethod -> - logger.info("Attempting {}", encryptionMethod.name()) - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - testRunner.enqueue(Paths.get("src/test/resources/hello.txt")) - testRunner.clearTransferState() - testRunner.run() - - testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1) - - MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0) - testRunner.assertQueueEmpty() - - testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE) - testRunner.enqueue(flowFile) - testRunner.clearTransferState() - testRunner.run() - testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1) - - logger.info("Successfully decrypted {}", encryptionMethod.name()) - - flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0) - flowFile.assertContentEquals(new File("src/test/resources/hello.txt")) - } - } - - @Test - void testDecryptAesCbcNoPadding() { - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - final String RAW_KEY_HEX = "ab" * 16 - testRunner.setProperty(EncryptContent.RAW_KEY_HEX, RAW_KEY_HEX) - testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()) - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, EncryptionMethod.AES_CBC_NO_PADDING.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE) - - final String content = "ExactBlockSizeRequiredForProcess" - final byte[] bytes = content.getBytes(StandardCharsets.UTF_8) - final ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes) - final ByteArrayOutputStream outputStream = new ByteArrayOutputStream() - - final KeyedEncryptor encryptor = new KeyedEncryptor(EncryptionMethod.AES_CBC_NO_PADDING, Hex.decodeHex(RAW_KEY_HEX)) - encryptor.encryptionCallback.process(inputStream, outputStream) - outputStream.close() - - final byte[] encrypted = outputStream.toByteArray() - testRunner.enqueue(encrypted) - testRunner.run() - - testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1) - MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0) - flowFile.assertContentEquals(content) - } - - // TODO: Implement - @Test - void testArgon2EncryptionShouldWriteAttributesWithEncryptionMetadata() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - KeyDerivationFunction kdf = KeyDerivationFunction.ARGON2 - EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC - logger.info("Attempting encryption with {}", encryptionMethod.name()) - - testRunner.setProperty(EncryptContent.PASSWORD, "thisIsABadPassword") - testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()) - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - String PLAINTEXT = "This is a plaintext message. " - - // Act - testRunner.enqueue(PLAINTEXT) - testRunner.clearTransferState() - testRunner.run() - - // Assert - testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1) - logger.info("Successfully encrypted with {}", encryptionMethod.name()) - - MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0) - testRunner.assertQueueEmpty() - - printFlowFileAttributes(flowFile.getAttributes()) - - byte[] flowfileContentBytes = flowFile.getData() - String flowfileContent = flowFile.getContent() - - int ivDelimiterStart = CipherUtility.findSequence(flowfileContentBytes, RandomIVPBECipherProvider.IV_DELIMITER) - logger.info("IV delimiter starts at ${ivDelimiterStart}") - - final byte[] EXPECTED_KDF_SALT_BYTES = extractFullSaltFromCipherBytes(flowfileContentBytes) - final String EXPECTED_KDF_SALT = new String(EXPECTED_KDF_SALT_BYTES) - final String EXPECTED_SALT_HEX = extractRawSaltHexFromFullSalt(EXPECTED_KDF_SALT_BYTES, kdf) - logger.info("Extracted expected raw salt (hex): ${EXPECTED_SALT_HEX}") - - final String EXPECTED_IV_HEX = Hex.encodeHexString(flowfileContentBytes[(ivDelimiterStart - 16).. attributes) { - int maxLength = attributes.keySet()*.length().max() - attributes.sort().each { attr, value -> - logger.info("Attribute: ${attr.padRight(maxLength)}: ${value}") - } - } - - @Test - void testKeyedEncryptionShouldWriteAttributesWithEncryptionMetadata() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - KeyDerivationFunction kdf = KeyDerivationFunction.NONE - EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC - logger.info("Attempting encryption with {}", encryptionMethod.name()) - - testRunner.setProperty(EncryptContent.RAW_KEY_HEX, "0123456789ABCDEFFEDCBA9876543210") - testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()) - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - String PLAINTEXT = "This is a plaintext message. " - - // Act - testRunner.enqueue(PLAINTEXT) - testRunner.clearTransferState() - testRunner.run() - - // Assert - testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1) - logger.info("Successfully encrypted with {}", encryptionMethod.name()) - - MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0) - testRunner.assertQueueEmpty() - - printFlowFileAttributes(flowFile.getAttributes()) - - byte[] flowfileContentBytes = flowFile.getData() - String flowfileContent = flowFile.getContent() - logger.info("Cipher text (${flowfileContentBytes.length}): ${Hex.encodeHexString(flowfileContentBytes)}") - - int ivDelimiterStart = CipherUtility.findSequence(flowfileContentBytes, RandomIVPBECipherProvider.IV_DELIMITER) - logger.info("IV delimiter starts at ${ivDelimiterStart}") - assert ivDelimiterStart == 16 - - def diff = calculateTimestampDifference(new Date(), flowFile.getAttribute("encryptcontent.timestamp")) - logger.info("Timestamp difference: ${diff}") - - // Assert the timestamp attribute was written and is accurate - assert diff.toMilliseconds() < 1_000 - - final String EXPECTED_IV_HEX = Hex.encodeHexString(flowfileContentBytes[0.. (ms): ${dateMillis}") - Date parsedTimestamp = formatter.parse(timestamp) - long parsedTimestampMillis = parsedTimestamp.toInstant().toEpochMilli() - logger.info("Parsed timestamp ${timestamp} -> (ms): ${parsedTimestampMillis}") - - TimeCategory.minus(date, parsedTimestamp) - } - - static byte[] extractFullSaltFromCipherBytes(byte[] cipherBytes) { - int saltDelimiterStart = CipherUtility.findSequence(cipherBytes, RandomIVPBECipherProvider.SALT_DELIMITER) - logger.info("Salt delimiter starts at ${saltDelimiterStart}") - byte[] saltBytes = cipherBytes[0.. results - MockProcessContext pc - - def encryptionMethods = EncryptionMethod.values().findAll { it.algorithm.startsWith("PBE") } - - boolean limitedStrengthCrypto = false - boolean allowWeakCrypto = false - testRunner.setProperty(EncryptContent.ALLOW_WEAK_CRYPTO, WEAK_CRYPTO_NOT_ALLOWED) - - // Use .find instead of .each to allow "breaks" using return false - encryptionMethods.find { EncryptionMethod encryptionMethod -> - // Determine the minimum of the algorithm-accepted length or the global safe minimum to ensure only one validation result - def shortPasswordLength = [PasswordBasedEncryptor.getMinimumSafePasswordLength() - 1, CipherUtility.getMaximumPasswordLengthForAlgorithmOnLimitedStrengthCrypto(encryptionMethod) - 1].min() - String shortPassword = "x" * shortPasswordLength - if (encryptionMethod.isUnlimitedStrength() || encryptionMethod.isKeyedCipher()) { - return false - // cannot test unlimited strength in unit tests because it's not enabled by the JVM by default. - } - - testRunner.setProperty(EncryptContent.PASSWORD, shortPassword) - logger.info("Attempting ${encryptionMethod.algorithm} with password of length ${shortPasswordLength}") - logger.state("Limited strength crypto ${limitedStrengthCrypto} and allow weak crypto: ${allowWeakCrypto}") - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - testRunner.clearTransferState() - testRunner.enqueue(new byte[0]) - pc = (MockProcessContext) testRunner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - logger.expected(results) - Assert.assertEquals(1, results.size()) - ValidationResult passwordLengthVR = results.first() - - String expectedResult = "'Password' is invalid because Password length less than ${PasswordBasedEncryptor.getMinimumSafePasswordLength()} characters is potentially unsafe. " + - "See Admin Guide." - String message = "'" + passwordLengthVR.toString() + "' contains '" + expectedResult + "'" - Assert.assertTrue(message, passwordLengthVR.toString().contains(expectedResult)) - } - } - - @Test - void testShouldNotCheckLengthOfPasswordWhenAllowed() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NIFI_LEGACY.name()) - - Collection results - MockProcessContext pc - - def encryptionMethods = EncryptionMethod.values().findAll { it.algorithm.startsWith("PBE") } - - boolean limitedStrengthCrypto = false - boolean allowWeakCrypto = true - testRunner.setProperty(EncryptContent.ALLOW_WEAK_CRYPTO, WEAK_CRYPTO_ALLOWED) - - // Use .find instead of .each to allow "breaks" using return false - encryptionMethods.find { EncryptionMethod encryptionMethod -> - // Determine the minimum of the algorithm-accepted length or the global safe minimum to ensure only one validation result - def shortPasswordLength = [PasswordBasedEncryptor.getMinimumSafePasswordLength() - 1, CipherUtility.getMaximumPasswordLengthForAlgorithmOnLimitedStrengthCrypto(encryptionMethod) - 1].min() - String shortPassword = "x" * shortPasswordLength - if (encryptionMethod.isUnlimitedStrength() || encryptionMethod.isKeyedCipher()) { - return false - // cannot test unlimited strength in unit tests because it's not enabled by the JVM by default. - } - - testRunner.setProperty(EncryptContent.PASSWORD, shortPassword) - logger.info("Attempting ${encryptionMethod.algorithm} with password of length ${shortPasswordLength}") - logger.state("Limited strength crypto ${limitedStrengthCrypto} and allow weak crypto: ${allowWeakCrypto}") - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - testRunner.clearTransferState() - testRunner.enqueue(new byte[0]) - pc = (MockProcessContext) testRunner.getProcessContext() - - // Act - results = pc.validate() - - // Assert - Assert.assertEquals(results.toString(), 0, results.size()) - } - } - - @Test - void testPGPPasswordShouldSupportExpressionLanguage() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE) - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, EncryptionMethod.PGP.name()) - testRunner.setProperty(EncryptContent.PRIVATE_KEYRING, "src/test/resources/TestEncryptContent/secring.gpg") - - Collection results - MockProcessContext pc - - // Verify this is the correct password - final String passphraseWithoutEL = "thisIsABadPassword" - testRunner.setProperty(EncryptContent.PRIVATE_KEYRING_PASSPHRASE, passphraseWithoutEL) - - testRunner.clearTransferState() - testRunner.enqueue(new byte[0]) - pc = (MockProcessContext) testRunner.getProcessContext() - - results = pc.validate() - Assert.assertEquals(results.toString(), 0, results.size()) - - final String passphraseWithEL = "\${literal('thisIsABadPassword')}" - testRunner.setProperty(EncryptContent.PRIVATE_KEYRING_PASSPHRASE, passphraseWithEL) - - testRunner.clearTransferState() - testRunner.enqueue(new byte[0]) - - // Act - results = pc.validate() - - // Assert - Assert.assertEquals(results.toString(), 0, results.size()) - } - - @Test - void testArgon2ShouldIncludeFullSalt() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()) - testRunner.setProperty(EncryptContent.PASSWORD, "thisIsABadPassword") - testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.ARGON2.name()) - - EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC - - logger.info("Attempting {}", encryptionMethod.name()) - testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()) - testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE) - - // Act - testRunner.enqueue(Paths.get("src/test/resources/hello.txt")) - testRunner.clearTransferState() - testRunner.run() - - // Assert - testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1) - - MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0) - testRunner.assertQueueEmpty() - - def flowFileContent = flowFile.getContent() - logger.info("Flowfile content (${flowFile.getData().length}): ${Hex.encodeHexString(flowFile.getData())}") - - def fullSalt = flowFileContent.substring(0, flowFileContent.indexOf(new String(RandomIVPBECipherProvider.SALT_DELIMITER, StandardCharsets.UTF_8))) - logger.info("Full salt (${fullSalt.size()}): ${fullSalt}") - - boolean isValidFormattedSalt = Argon2CipherProvider.isArgon2FormattedSalt(fullSalt) - logger.info("Salt is Argon2 format: ${isValidFormattedSalt}") - assert isValidFormattedSalt - - def FULL_SALT_LENGTH_RANGE = (49..57) - boolean fullSaltIsValidLength = FULL_SALT_LENGTH_RANGE.contains(fullSalt.bytes.length) - logger.info("Salt length (${fullSalt.length()}) in valid range (${FULL_SALT_LENGTH_RANGE})") - assert fullSaltIsValidLength - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestFlattenJson.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestFlattenJson.groovy deleted file mode 100644 index 2485dc0b22..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestFlattenJson.groovy +++ /dev/null @@ -1,441 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License") you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.nifi.processors.standard - -import groovy.json.JsonSlurper -import org.apache.nifi.util.TestRunners -import org.junit.Assert -import org.junit.Test -import static groovy.json.JsonOutput.prettyPrint -import static groovy.json.JsonOutput.toJson - -class TestFlattenJson { - @Test - void testFlatten() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - test: [ - msg: "Hello, world" - ], - first: [ - second: [ - third: [ - "one", "two", "three", "four", "five" - ] - ] - ] - ])) - baseTest(testRunner, json, 2) { parsed -> - Assert.assertEquals("test.msg should exist, but doesn't", parsed["test.msg"], "Hello, world") - Assert.assertEquals("Three level block doesn't exist.", parsed["first.second.third"], [ - "one", "two", "three", "four", "five" - ]) - } - } - - void baseTest(testRunner, String json, int keyCount, Closure c) { - baseTest(testRunner, json, [:], keyCount, c); - } - - void baseTest(def testRunner, String json, Map attrs, int keyCount, Closure c) { - testRunner.enqueue(json, attrs) - testRunner.run(1, true) - testRunner.assertTransferCount(FlattenJson.REL_FAILURE, 0) - testRunner.assertTransferCount(FlattenJson.REL_SUCCESS, 1) - - def flowFiles = testRunner.getFlowFilesForRelationship(FlattenJson.REL_SUCCESS) - def content = testRunner.getContentAsByteArray(flowFiles[0]) - def asJson = new String(content) - def slurper = new JsonSlurper() - def parsed = slurper.parseText(asJson) as Map - - Assert.assertEquals("Too many keys", keyCount, parsed.size()) - c.call(parsed) - } - - @Test - void testFlattenRecordSet() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - [ - first: [ - second: "Hello" - ] - ], - [ - first: [ - second: "World" - ] - ] - ])) - - def expected = ["Hello", "World"] - baseTest(testRunner, json, 2) { parsed -> - Assert.assertTrue("Not a list", parsed instanceof List) - 0.upto(parsed.size() - 1) { - Assert.assertEquals("Missing values.", parsed[it]["first.second"], expected[it]) - } - } - } - - @Test - void testDifferentSeparator() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - third: [ - "one", "two", "three", "four", "five" - ] - ] - ] - ])) - testRunner.setProperty(FlattenJson.SEPARATOR, "_") - baseTest(testRunner, json, 1) { parsed -> - Assert.assertEquals("Separator not applied.", parsed["first_second_third"], [ - "one", "two", "three", "four", "five" - ]) - } - } - - @Test - void testExpressionLanguage() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - third: [ - "one", "two", "three", "four", "five" - ] - ] - ] - ])) - - testRunner.setValidateExpressionUsage(true); - testRunner.setProperty(FlattenJson.SEPARATOR, '${separator.char}') - baseTest(testRunner, json, ["separator.char": "_"], 1) { parsed -> - Assert.assertEquals("Separator not applied.", parsed["first_second_third"], [ - "one", "two", "three", "four", "five" - ]) - } - } - - @Test - void testFlattenModeNormal() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - third: [ - "one", "two", "three", "four", "five" - ] - ] - ] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL) - baseTest(testRunner, json,5) { parsed -> - Assert.assertEquals("Separator not applied.", "one", parsed["first.second.third[0]"]) - } - } - - @Test - void testFlattenModeKeepArrays() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - [ - x: 1, - y: 2, - z: [3, 4, 5] - ], - [ 6, 7, 8], - [ - [9, 10], - 11, - 12 - ] - ], - "third" : [ - a: "b", - c: "d", - e: "f" - ] - ] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_ARRAYS) - baseTest(testRunner, json,4) { parsed -> - assert parsed["first.second"] instanceof List // [{x=1, y=2, z=[3, 4, 5]}, [6, 7, 8], [[9, 10], 11, 12]] - assert parsed["first.second"][1] == [6, 7, 8] - Assert.assertEquals("Separator not applied.", "b", parsed["first.third.a"]) - } - } - - @Test - void testFlattenModeKeepPrimitiveArrays() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - [ - x: 1, - y: 2, - z: [3, 4, 5] - ], - [ 6, 7, 8], - [ - [9, 10], - 11, - 12 - ] - ], - "third" : [ - a: "b", - c: "d", - e: "f" - ] - ] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_PRIMITIVE_ARRAYS) - baseTest(testRunner, json,10) { parsed -> - Assert.assertEquals("Separator not applied.", 1, parsed["first.second[0].x"]) - Assert.assertEquals("Separator not applied.", [3, 4, 5], parsed["first.second[0].z"]) - Assert.assertEquals("Separator not applied.", [9, 10], parsed["first.second[2][0]"]) - Assert.assertEquals("Separator not applied.", 11, parsed["first.second[2][1]"]) - Assert.assertEquals("Separator not applied.", 12, parsed["first.second[2][2]"]) - Assert.assertEquals("Separator not applied.", "d", parsed["first.third.c"]) - } - } - - @Test - void testFlattenModeDotNotation() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - third: [ - "one", "two", "three", "four", "five" - ] - ] - ] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_DOT_NOTATION) - baseTest(testRunner, json,5) { parsed -> - Assert.assertEquals("Separator not applied.", "one", parsed["first.second.third.0"]) - } - } - - @Test - void testFlattenSlash() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - first: [ - second: [ - third: [ - "http://localhost/value1", "http://localhost/value2" - ] - ] - ] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL) - baseTest(testRunner, json,2) { parsed -> - Assert.assertEquals("Separator not applied.", "http://localhost/value1", parsed["first.second.third[0]"]) - } - } - - @Test - void testEscapeForJson() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ name: "José" - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL) - baseTest(testRunner, json,1) { parsed -> - Assert.assertEquals("Separator not applied.", "José", parsed["name"]) - } - } - - @Test - void testUnFlatten() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - "test.msg": "Hello, world", - "first.second.third": [ "one", "two", "three", "four", "five" ] - ])) - - testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN) - baseTest(testRunner, json, 2) { parsed -> - assert parsed.test instanceof Map - assert parsed.test.msg == "Hello, world" - assert parsed.first.second.third == [ "one", "two", "three", "four", "five" ] - } - } - - @Test - void testUnFlattenWithDifferentSeparator() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - "first_second_third": [ "one", "two", "three", "four", "five" ] - ])) - - testRunner.setProperty(FlattenJson.SEPARATOR, "_") - testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN) - baseTest(testRunner, json, 1) { parsed -> - assert parsed.first instanceof Map - assert parsed.first.second.third == [ "one", "two", "three", "four", "five" ] - } - } - - @Test - void testUnFlattenForKeepArraysMode() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - "a.b": 1, - "a.c": [ - false, - ["i.j": [ false, true, "xy" ] ] - ] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_ARRAYS) - testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN) - baseTest(testRunner, json, 1) { parsed -> - assert parsed.a instanceof Map - assert parsed.a.b == 1 - assert parsed.a.c[0] == false - assert parsed.a.c[1].i instanceof Map - assert parsed.a.c[1].i.j == [false, true, "xy"] - } - } - - @Test - void testUnFlattenForKeepPrimitiveArraysMode() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - "first.second[0].x": 1, - "first.second[0].y": 2, - "first.second[0].z": [3, 4, 5], - "first.second[1]": [6, 7, 8], - "first.second[2][0]": [9, 10], - "first.second[2][1]": 11, - "first.second[2][2]": 12, - "first.third.a": "b", - "first.third.c": "d", - "first.third.e": "f" - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_PRIMITIVE_ARRAYS) - testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN) - baseTest(testRunner, json, 1) { parsed -> - assert parsed.first instanceof Map - assert parsed.first.second[0].x == 1 - assert parsed.first.second[2][0] == [9, 10] - assert parsed.first.third.c == "d" - } - } - - @Test - void testUnFlattenForDotNotationMode() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - "first.second.third.0": ["one", "two", "three", "four", "five"] - ])) - - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_DOT_NOTATION) - testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN) - baseTest(testRunner, json,1) { parsed -> - assert parsed.first instanceof Map - assert parsed.first.second.third[0] == ["one", "two", "three", "four", "five"] - } - } - - @Test - void testFlattenWithIgnoreReservedCharacters() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - "first": [ - "second.third": "Hello", - "fourth" : "World" - ] - ])) - - testRunner.setProperty(FlattenJson.IGNORE_RESERVED_CHARACTERS, "true") - - baseTest(testRunner, json, 2) { parsed -> - Assert.assertEquals("Separator not applied.", parsed["first.second.third"], "Hello") - Assert.assertEquals("Separator not applied.", parsed["first.fourth"], "World") - } - } - - @Test - void testFlattenRecordSetWithIgnoreReservedCharacters() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - [ - "first": [ - "second_third": "Hello" - ] - ], - [ - "first": [ - "second_third": "World" - ] - ] - ])) - testRunner.setProperty(FlattenJson.SEPARATOR, "_") - testRunner.setProperty(FlattenJson.IGNORE_RESERVED_CHARACTERS, "true") - - def expected = ["Hello", "World"] - baseTest(testRunner, json, 2) { parsed -> - Assert.assertTrue("Not a list", parsed instanceof List) - 0.upto(parsed.size() - 1) { - Assert.assertEquals("Missing values.", parsed[it]["first_second_third"], expected[it]) - } - } - } - - @Test - void testFlattenModeNormalWithIgnoreReservedCharacters() { - def testRunner = TestRunners.newTestRunner(FlattenJson.class) - def json = prettyPrint(toJson([ - [ - "first": [ - "second_third": "Hello" - ] - ], - [ - "first": [ - "second_third": "World" - ] - ] - ])) - testRunner.setProperty(FlattenJson.SEPARATOR, "_") - testRunner.setProperty(FlattenJson.IGNORE_RESERVED_CHARACTERS, "true") - testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL) - - baseTest(testRunner, json, 2) { parsed -> - Assert.assertEquals("Separator not applied.", "Hello", parsed["[0]_first_second_third"]) - Assert.assertEquals("Separator not applied.", "World", parsed["[1]_first_second_third"]) - } - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy deleted file mode 100644 index 233d5ec94b..0000000000 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy +++ /dev/null @@ -1,1765 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.standard - -import org.apache.commons.dbcp2.DelegatingConnection -import org.apache.nifi.processor.exception.ProcessException -import org.apache.nifi.processor.util.pattern.RollbackOnFailure -import org.apache.nifi.reporting.InitializationException -import org.apache.nifi.serialization.SimpleRecordSchema -import org.apache.nifi.serialization.record.MapRecord -import org.apache.nifi.serialization.record.MockRecordFailureType -import org.apache.nifi.serialization.record.MockRecordParser -import org.apache.nifi.serialization.record.RecordField -import org.apache.nifi.serialization.record.RecordFieldType -import org.apache.nifi.serialization.record.RecordSchema -import org.apache.nifi.util.MockFlowFile -import org.apache.nifi.util.TestRunner -import org.apache.nifi.util.TestRunners -import org.apache.nifi.util.file.FileUtils -import org.junit.AfterClass -import org.junit.Before -import org.junit.BeforeClass -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 - -import java.sql.Blob -import java.sql.Clob -import java.sql.Connection -import java.sql.Date -import java.sql.DriverManager -import java.sql.PreparedStatement -import java.sql.ResultSet -import java.sql.SQLDataException -import java.sql.SQLException -import java.sql.SQLNonTransientConnectionException -import java.sql.Statement -import java.time.LocalDate -import java.time.ZoneOffset -import java.util.function.Supplier - -import static org.junit.Assert.assertEquals -import static org.junit.Assert.assertFalse -import static org.junit.Assert.assertNotNull -import static org.junit.Assert.assertNull -import static org.junit.Assert.assertTrue -import static org.junit.Assert.fail -import static org.mockito.ArgumentMatchers.anyMap -import static org.mockito.Mockito.doAnswer -import static org.mockito.Mockito.spy -import static org.mockito.Mockito.times -import static org.mockito.Mockito.verify - -/** - * Unit tests for the PutDatabaseRecord processor - */ -@RunWith(JUnit4.class) -class TestPutDatabaseRecord { - - private static final String createPersons = "CREATE TABLE PERSONS (id integer primary key, name varchar(100)," + - " code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000), dt date)" - private static final String createPersonsSchema1 = "CREATE TABLE SCHEMA1.PERSONS (id integer primary key, name varchar(100)," + - " code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000), dt date)" - private static final String createPersonsSchema2 = "CREATE TABLE SCHEMA2.PERSONS (id2 integer primary key, name varchar(100)," + - " code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000), dt date)" - private final static String DB_LOCATION = "target/db_pdr" - - TestRunner runner - PutDatabaseRecord processor - DBCPServiceSimpleImpl dbcp - - @BeforeClass - static void setupBeforeClass() throws IOException { - System.setProperty("derby.stream.error.file", "target/derby.log") - - // remove previous test database, if any - final File dbLocation = new File(DB_LOCATION) - try { - FileUtils.deleteFile(dbLocation, true) - } catch (IOException ignore) { - // Do nothing, may not have existed - } - } - - @AfterClass - static void cleanUpAfterClass() throws Exception { - try { - DriverManager.getConnection("jdbc:derby:" + DB_LOCATION + ";shutdown=true") - } catch (SQLNonTransientConnectionException ignore) { - // Do nothing, this is what happens at Derby shutdown - } - // remove previous test database, if any - final File dbLocation = new File(DB_LOCATION) - try { - FileUtils.deleteFile(dbLocation, true) - } catch (IOException ignore) { - // Do nothing, may not have existed - } - } - - @Before - void setUp() throws Exception { - processor = new PutDatabaseRecord() - //Mock the DBCP Controller Service so we can control the Results - dbcp = spy(new DBCPServiceSimpleImpl(DB_LOCATION)) - - final Map dbcpProperties = new HashMap<>() - - runner = TestRunners.newTestRunner(processor) - runner.addControllerService("dbcp", dbcp, dbcpProperties) - runner.enableControllerService(dbcp) - runner.setProperty(PutDatabaseRecord.DBCP_SERVICE, "dbcp") - } - - @Test - void testGeneratePreparedStatements() throws Exception { - - final List fields = [new RecordField('id', RecordFieldType.INT.dataType), - new RecordField('name', RecordFieldType.STRING.dataType), - new RecordField('code', RecordFieldType.INT.dataType), - new RecordField('non_existing', RecordFieldType.BOOLEAN.dataType)] - - def schema = [ - getFields : {fields}, - getFieldCount: {fields.size()}, - getField : {int index -> fields[index]}, - getDataTypes : {fields.collect {it.dataType}}, - getFieldNames: {fields.collect {it.fieldName}}, - getDataType : {fieldName -> fields.find {it.fieldName == fieldName}.dataType} - ] as RecordSchema - - def tableSchema = [ - [ - new PutDatabaseRecord.ColumnDescription('id', 4, true, 2, false), - new PutDatabaseRecord.ColumnDescription('name', 12, true, 255, true), - new PutDatabaseRecord.ColumnDescription('code', 4, true, 10, true) - ], - false, - ['id'] as Set, - '' - ] as PutDatabaseRecord.TableSchema - - runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, 'false') - runner.setProperty(PutDatabaseRecord.UNMATCHED_FIELD_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_FIELD) - runner.setProperty(PutDatabaseRecord.UNMATCHED_COLUMN_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_COLUMN) - runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, 'false') - runner.setProperty(PutDatabaseRecord.QUOTE_TABLE_IDENTIFIER, 'false') - def settings = new PutDatabaseRecord.DMLSettings(runner.getProcessContext()) - - processor.with { - - assertEquals('INSERT INTO PERSONS (id, name, code) VALUES (?,?,?)', - generateInsert(schema, 'PERSONS', tableSchema, settings).sql) - - assertEquals('UPDATE PERSONS SET name = ?, code = ? WHERE id = ?', - generateUpdate(schema, 'PERSONS', null, tableSchema, settings).sql) - - assertEquals('DELETE FROM PERSONS WHERE (id = ?) AND (name = ? OR (name is null AND ? is null)) AND (code = ? OR (code is null AND ? is null))', - generateDelete(schema, 'PERSONS', tableSchema, settings).sql) - } - } - - @Test - void testGeneratePreparedStatementsFailUnmatchedField() throws Exception { - - final List fields = [new RecordField('id', RecordFieldType.INT.dataType), - new RecordField('name', RecordFieldType.STRING.dataType), - new RecordField('code', RecordFieldType.INT.dataType), - new RecordField('non_existing', RecordFieldType.BOOLEAN.dataType)] - - def schema = [ - getFields : {fields}, - getFieldCount: {fields.size()}, - getField : {int index -> fields[index]}, - getDataTypes : {fields.collect {it.dataType}}, - getFieldNames: {fields.collect {it.fieldName}}, - getDataType : {fieldName -> fields.find {it.fieldName == fieldName}.dataType} - ] as RecordSchema - - def tableSchema = [ - [ - new PutDatabaseRecord.ColumnDescription('id', 4, true, 2, false), - new PutDatabaseRecord.ColumnDescription('name', 12, true, 255, true), - new PutDatabaseRecord.ColumnDescription('code', 4, true, 10, true) - ], - false, - ['id'] as Set, - '' - - ] as PutDatabaseRecord.TableSchema - - runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, 'false') - runner.setProperty(PutDatabaseRecord.UNMATCHED_FIELD_BEHAVIOR, PutDatabaseRecord.FAIL_UNMATCHED_FIELD) - runner.setProperty(PutDatabaseRecord.UNMATCHED_COLUMN_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_COLUMN) - runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, 'false') - runner.setProperty(PutDatabaseRecord.QUOTE_TABLE_IDENTIFIER, 'false') - def settings = new PutDatabaseRecord.DMLSettings(runner.getProcessContext()) - - processor.with { - - try { - generateInsert(schema, 'PERSONS', tableSchema, settings) - fail('generateInsert should fail with unmatched fields') - } catch (SQLDataException e) { - assertEquals("Cannot map field 'non_existing' to any column in the database\nColumns: id,name,code", e.getMessage()) - } - - try { - generateUpdate(schema, 'PERSONS', null, tableSchema, settings) - fail('generateUpdate should fail with unmatched fields') - } catch (SQLDataException e) { - assertEquals("Cannot map field 'non_existing' to any column in the database\nColumns: id,name,code", e.getMessage()) - } - - try { - generateDelete(schema, 'PERSONS', tableSchema, settings) - fail('generateDelete should fail with unmatched fields') - } catch (SQLDataException e) { - assertEquals("Cannot map field 'non_existing' to any column in the database\nColumns: id,name,code", e.getMessage()) - } - } - } - - @Test - void testInsert() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("dt", RecordFieldType.DATE) - - LocalDate testDate1 = LocalDate.of(2021, 1, 26) - Date jdbcDate1 = Date.valueOf(testDate1) // in local TZ - LocalDate testDate2 = LocalDate.of(2021, 7, 26) - Date jdbcDate2 = Date.valueOf(testDate2) // in local TZ - - parser.addRecord(1, 'rec1', 101, jdbcDate1) - parser.addRecord(2, 'rec2', 102, jdbcDate2) - parser.addRecord(3, 'rec3', 103, null) - parser.addRecord(4, 'rec4', 104, null) - parser.addRecord(5, null, 105, null) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertEquals(jdbcDate1.toString(), rs.getDate(4).toString()) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(102, rs.getInt(3)) - assertEquals(jdbcDate2.toString(), rs.getDate(4).toString()) - assertTrue(rs.next()) - assertEquals(3, rs.getInt(1)) - assertEquals('rec3', rs.getString(2)) - assertEquals(103, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertTrue(rs.next()) - assertEquals(4, rs.getInt(1)) - assertEquals('rec4', rs.getString(2)) - assertEquals(104, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertTrue(rs.next()) - assertEquals(5, rs.getInt(1)) - assertNull(rs.getString(2)) - assertEquals(105, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertNonRequiredColumns() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("dt", RecordFieldType.DATE) - - LocalDate testDate1 = LocalDate.of(2021, 1, 26) - Date jdbcDate1 = Date.valueOf(testDate1) // in local TZ - LocalDate testDate2 = LocalDate.of(2021, 7, 26) - Date jdbcDate2 = Date.valueOf(testDate2) // in local TZ - - parser.addRecord(1, 'rec1', jdbcDate1) - parser.addRecord(2, 'rec2', jdbcDate2) - parser.addRecord(3, 'rec3', null) - parser.addRecord(4, 'rec4', null) - parser.addRecord(5, null, null) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - // Zero value because of the constraint - assertEquals(0, rs.getInt(3)) - assertEquals(jdbcDate1.toString(), rs.getDate(4).toString()) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(0, rs.getInt(3)) - assertEquals(jdbcDate2.toString(), rs.getDate(4).toString()) - assertTrue(rs.next()) - assertEquals(3, rs.getInt(1)) - assertEquals('rec3', rs.getString(2)) - assertEquals(0, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertTrue(rs.next()) - assertEquals(4, rs.getInt(1)) - assertEquals('rec4', rs.getString(2)) - assertEquals(0, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertTrue(rs.next()) - assertEquals(5, rs.getInt(1)) - assertNull(rs.getString(2)) - assertEquals(0, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertBatchUpdateException() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - parser.addRecord(2, 'rec2', 102) - parser.addRecord(3, 'rec3', 1000) // This record violates the constraint on the 'code' column so should result in FlowFile being routed to failure - parser.addRecord(4, 'rec4', 104) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_FAILURE, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - // Transaction should be rolled back and table should remain empty. - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertBatchUpdateExceptionRollbackOnFailure() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - parser.addRecord(2, 'rec2', 102) - parser.addRecord(3, 'rec3', 1000) - parser.addRecord(4, 'rec4', 104) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(RollbackOnFailure.ROLLBACK_ON_FAILURE, 'true') - - runner.enqueue(new byte[0]) - runner.run() - - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - // Transaction should be rolled back and table should remain empty. - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertNoTableSpecified() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, '${not.a.real.attr}') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - } - - @Test - void testInsertNoTableExists() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS2') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - MockFlowFile flowFile = runner.getFlowFilesForRelationship(PutDatabaseRecord.REL_FAILURE).get(0); - final String errorMessage = flowFile.getAttribute("putdatabaserecord.error") - assertTrue(errorMessage.contains("PERSONS2")) - } - - @Test - void testInsertViaSqlStatementType() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("sql", RecordFieldType.STRING) - - parser.addRecord('''INSERT INTO PERSONS (id, name, code) VALUES (1, 'rec1',101)''') - parser.addRecord('''INSERT INTO PERSONS (id, name, code) VALUES (2, 'rec2',102)''') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, 'sql') - - def attrs = [:] - attrs[PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE] = 'sql' - runner.enqueue(new byte[0], attrs) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(102, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testMultipleInsertsViaSqlStatementType() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("sql", RecordFieldType.STRING) - - parser.addRecord('''INSERT INTO PERSONS (id, name, code) VALUES (1, 'rec1',101);INSERT INTO PERSONS (id, name, code) VALUES (2, 'rec2',102)''') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, 'sql') - runner.setProperty(PutDatabaseRecord.ALLOW_MULTIPLE_STATEMENTS, 'true') - - def attrs = [:] - attrs[PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE] = 'sql' - runner.enqueue(new byte[0], attrs) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(102, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testMultipleInsertsViaSqlStatementTypeBadSQL() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("sql", RecordFieldType.STRING) - - parser.addRecord('''INSERT INTO PERSONS (id, name, code) VALUES (1, 'rec1',101); - INSERT INTO PERSONS (id, name, code) VALUES (2, 'rec2',102); - INSERT INTO PERSONS2 (id, name, code) VALUES (2, 'rec2',102);''') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, 'sql') - runner.setProperty(PutDatabaseRecord.ALLOW_MULTIPLE_STATEMENTS, 'true') - - def attrs = [:] - attrs[PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE] = 'sql' - runner.enqueue(new byte[0], attrs) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - // The first two legitimate statements should have been rolled back - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInvalidData() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - parser.addRecord(2, 'rec2', 102) - parser.addRecord(3, 'rec3', 104) - - parser.failAfter(1) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_FAILURE, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - // Transaction should be rolled back and table should remain empty. - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testIOExceptionOnReadData() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - parser.addRecord(2, 'rec2', 102) - parser.addRecord(3, 'rec3', 104) - - parser.failAfter(1, MockRecordFailureType.IO_EXCEPTION) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_FAILURE, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - // Transaction should be rolled back and table should remain empty. - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testSqlStatementTypeNoValue() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("sql", RecordFieldType.STRING) - - parser.addRecord('') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, 'sql') - - def attrs = [:] - attrs[PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE] = 'sql' - runner.enqueue(new byte[0], attrs) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - } - - @Test - void testSqlStatementTypeNoValueRollbackOnFailure() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("sql", RecordFieldType.STRING) - - parser.addRecord('') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, 'sql') - runner.setProperty(RollbackOnFailure.ROLLBACK_ON_FAILURE, 'true') - - def attrs = [:] - attrs[PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE] = 'sql' - runner.enqueue(new byte[0], attrs) - - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 0) - } - - @Test - void testUpdate() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 201) - parser.addRecord(2, 'rec2', 202) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - // Set some existing records with different values for name and code - final Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute('''INSERT INTO PERSONS VALUES (1,'x1',101, null)''') - stmt.execute('''INSERT INTO PERSONS VALUES (2,'x2',102, null)''') - stmt.close() - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(202, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testUpdatePkNotFirst() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable('CREATE TABLE PERSONS (name varchar(100), id integer primary key, code integer)') - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord('rec1', 1, 201) - parser.addRecord('rec2', 2, 202) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - // Set some existing records with different values for name and code - final Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute('''INSERT INTO PERSONS VALUES ('x1', 1, 101)''') - stmt.execute('''INSERT INTO PERSONS VALUES ('x2', 2, 102)''') - stmt.close() - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals('rec1', rs.getString(1)) - assertEquals(1, rs.getInt(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals('rec2', rs.getString(1)) - assertEquals(2, rs.getInt(2)) - assertEquals(202, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testUpdateMultipleSchemas() throws InitializationException, ProcessException, SQLException, IOException { - // Manually create and drop the tables and schemas - def conn = dbcp.connection - def stmt = conn.createStatement() - stmt.execute('create schema SCHEMA1') - stmt.execute('create schema SCHEMA2') - stmt.execute(createPersonsSchema1) - stmt.execute(createPersonsSchema2) - - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 201) - parser.addRecord(2, 'rec2', 202) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.SCHEMA_NAME, "SCHEMA1") - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - // Set some existing records with different values for name and code - Exception e - ResultSet rs - try { - stmt.execute('''INSERT INTO SCHEMA1.PERSONS VALUES (1,'x1',101,null)''') - stmt.execute('''INSERT INTO SCHEMA2.PERSONS VALUES (2,'x2',102,null)''') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - rs = stmt.executeQuery('SELECT * FROM SCHEMA1.PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertFalse(rs.next()) - rs = stmt.executeQuery('SELECT * FROM SCHEMA2.PERSONS') - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - // Values should not have been updated - assertEquals('x2', rs.getString(2)) - assertEquals(102, rs.getInt(3)) - assertFalse(rs.next()) - } catch(ex) { - e = ex - } - - // Drop the schemas here so as not to interfere with other tests - stmt.execute("drop table SCHEMA1.PERSONS") - stmt.execute("drop table SCHEMA2.PERSONS") - stmt.execute("drop schema SCHEMA1 RESTRICT") - stmt.execute("drop schema SCHEMA2 RESTRICT") - stmt.close() - - // Don't proceed if there was a problem with the asserts - if(e) throw e - rs = conn.metaData.schemas - List schemas = new ArrayList<>() - while(rs.next()) { - schemas += rs.getString(1) - } - assertFalse(schemas.contains('SCHEMA1')) - assertFalse(schemas.contains('SCHEMA2')) - conn.close() - } - - @Test - void testUpdateAfterInsert() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 101) - parser.addRecord(2, 'rec2', 102) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(102, rs.getInt(3)) - assertFalse(rs.next()) - stmt.close() - runner.clearTransferState() - - parser.addRecord(1, 'rec1', 201) - parser.addRecord(2, 'rec2', 202) - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.enqueue(new byte[0]) - runner.run(1,true,false) - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(202, rs.getInt(3)) - assertFalse(rs.next()) - stmt.close() - conn.close() - } - - @Test - void testUpdateNoPrimaryKeys() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable('CREATE TABLE PERSONS (id integer, name varchar(100), code integer)') - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - parser.addRecord(1, 'rec1', 201) - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - MockFlowFile flowFile = runner.getFlowFilesForRelationship(PutDatabaseRecord.REL_FAILURE).get(0) - assertEquals('Table \'PERSONS\' not found or does not have a Primary Key and no Update Keys were specified', flowFile.getAttribute(PutDatabaseRecord.PUT_DATABASE_RECORD_ERROR)) - } - - @Test - void testUpdateSpecifyUpdateKeys() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable('CREATE TABLE PERSONS (id integer, name varchar(100), code integer)') - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 201) - parser.addRecord(2, 'rec2', 202) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, 'id') - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - // Set some existing records with different values for name and code - final Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute('''INSERT INTO PERSONS VALUES (1,'x1',101)''') - stmt.execute('''INSERT INTO PERSONS VALUES (2,'x2',102)''') - stmt.close() - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(202, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testUpdateSpecifyUpdateKeysNotFirst() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable('CREATE TABLE PERSONS (id integer, name varchar(100), code integer)') - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 201) - parser.addRecord(2, 'rec2', 202) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, 'code') - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - // Set some existing records with different values for name and code - final Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute('''INSERT INTO PERSONS VALUES (10,'x1',201)''') - stmt.execute('''INSERT INTO PERSONS VALUES (12,'x2',202)''') - stmt.close() - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(202, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testUpdateSpecifyQuotedUpdateKeys() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable('CREATE TABLE PERSONS ("id" integer, name varchar(100), code integer)') - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(1, 'rec1', 201) - parser.addRecord(2, 'rec2', 202) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE) - runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, '${updateKey}') - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, 'true') - - // Set some existing records with different values for name and code - final Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute('''INSERT INTO PERSONS VALUES (1,'x1',101)''') - stmt.execute('''INSERT INTO PERSONS VALUES (2,'x2',102)''') - stmt.close() - - runner.enqueue(new byte[0], ['updateKey': 'id']) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(202, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testDelete() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute("INSERT INTO PERSONS VALUES (1,'rec1', 101, null)") - stmt.execute("INSERT INTO PERSONS VALUES (2,'rec2', 102, null)") - stmt.execute("INSERT INTO PERSONS VALUES (3,'rec3', 103, null)") - stmt.close() - - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(2, 'rec2', 102) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.DELETE_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(3, rs.getInt(1)) - assertEquals('rec3', rs.getString(2)) - assertEquals(103, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testDeleteWithNulls() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - Connection conn = dbcp.getConnection() - Statement stmt = conn.createStatement() - stmt.execute("INSERT INTO PERSONS VALUES (1,'rec1', 101, null)") - stmt.execute("INSERT INTO PERSONS VALUES (2,'rec2', null, null)") - stmt.execute("INSERT INTO PERSONS VALUES (3,'rec3', 103, null)") - stmt.close() - - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - parser.addRecord(2, 'rec2', null) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.DELETE_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(3, rs.getInt(1)) - assertEquals('rec3', rs.getString(2)) - assertEquals(103, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testRecordPathOptions() { - recreateTable('CREATE TABLE PERSONS (id integer, name varchar(100), code integer)') - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - final List dataFields = new ArrayList<>(); - dataFields.add(new RecordField("id", RecordFieldType.INT.getDataType())) - dataFields.add(new RecordField("name", RecordFieldType.STRING.getDataType())) - dataFields.add(new RecordField("code", RecordFieldType.INT.getDataType())) - - final RecordSchema dataSchema = new SimpleRecordSchema(dataFields) - parser.addSchemaField("operation", RecordFieldType.STRING) - parser.addSchemaField(new RecordField("data", RecordFieldType.RECORD.getRecordDataType(dataSchema))) - - // CREATE, CREATE, CREATE, DELETE, UPDATE - parser.addRecord("INSERT", new MapRecord(dataSchema, ["id": 1, "name": "John Doe", "code": 55] as Map)) - parser.addRecord("INSERT", new MapRecord(dataSchema, ["id": 2, "name": "Jane Doe", "code": 44] as Map)) - parser.addRecord("INSERT", new MapRecord(dataSchema, ["id": 3, "name": "Jim Doe", "code": 2] as Map)) - parser.addRecord("DELETE", new MapRecord(dataSchema, ["id": 2, "name": "Jane Doe", "code": 44] as Map)) - parser.addRecord("UPDATE", new MapRecord(dataSchema, ["id": 1, "name": "John Doe", "code": 201] as Map)) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_RECORD_PATH) - runner.setProperty(PutDatabaseRecord.DATA_RECORD_PATH, "/data") - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE_RECORD_PATH, "/operation") - runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, 'id') - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_SUCCESS, 1) - - Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('John Doe', rs.getString(2)) - assertEquals(201, rs.getInt(3)) - assertTrue(rs.next()) - assertEquals(3, rs.getInt(1)) - assertEquals('Jim Doe', rs.getString(2)) - assertEquals(2, rs.getInt(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertWithMaxBatchSize() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - (1..11).each { - parser.addRecord(it, "rec$it".toString(), 100 + it) - } - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - runner.setProperty(PutDatabaseRecord.MAX_BATCH_SIZE, "5") - - Supplier spyStmt = createPreparedStatementSpy() - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - - assertEquals(11, getTableSize()) - - assertNotNull(spyStmt.get()) - verify(spyStmt.get(), times(3)).executeBatch() - } - - @Test - void testInsertWithDefaultMaxBatchSize() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - - (1..11).each { - parser.addRecord(it, "rec$it".toString(), 100 + it) - } - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - Supplier spyStmt = createPreparedStatementSpy() - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - - assertEquals(11, getTableSize()) - - assertNotNull(spyStmt.get()) - verify(spyStmt.get(), times(1)).executeBatch() - } - - private Supplier createPreparedStatementSpy() { - PreparedStatement spyStmt - doAnswer({ inv -> - new DelegatingConnection((Connection)inv.callRealMethod()) { - @Override - PreparedStatement prepareStatement(String sql) throws SQLException { - spyStmt = spy(getDelegate().prepareStatement(sql)) - } - } - }).when(dbcp).getConnection(anyMap()) - return { spyStmt } - } - - private int getTableSize() { - final Connection connection = dbcp.getConnection() - try { - final Statement stmt = connection.createStatement() - try { - final ResultSet rs = stmt.executeQuery('SELECT count(*) FROM PERSONS') - assertTrue(rs.next()) - rs.getInt(1) - } finally { - stmt.close() - } - } finally { - connection.close() - } - } - - private void recreateTable(String createSQL) throws ProcessException, SQLException { - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - try { - stmt.execute("drop table PERSONS") - } catch (SQLException ignore) { - // Do nothing, may not have existed - } - stmt.execute(createSQL) - stmt.close() - conn.close() - } - - @Test - void testGenerateTableName() throws Exception { - - final List fields = [new RecordField('id', RecordFieldType.INT.dataType), - new RecordField('name', RecordFieldType.STRING.dataType), - new RecordField('code', RecordFieldType.INT.dataType), - new RecordField('non_existing', RecordFieldType.BOOLEAN.dataType)] - - def schema = [ - getFields : {fields}, - getFieldCount: {fields.size()}, - getField : {int index -> fields[index]}, - getDataTypes : {fields.collect {it.dataType}}, - getFieldNames: {fields.collect {it.fieldName}}, - getDataType : {fieldName -> fields.find {it.fieldName == fieldName}.dataType} - ] as RecordSchema - - def tableSchema = [ - [ - new PutDatabaseRecord.ColumnDescription('id', 4, true, 2, false), - new PutDatabaseRecord.ColumnDescription('name', 12, true, 255, true), - new PutDatabaseRecord.ColumnDescription('code', 4, true, 10, true) - ], - false, - ['id'] as Set, - '"' - - ] as PutDatabaseRecord.TableSchema - - runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, 'false') - runner.setProperty(PutDatabaseRecord.UNMATCHED_FIELD_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_FIELD) - runner.setProperty(PutDatabaseRecord.UNMATCHED_COLUMN_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_COLUMN) - runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, 'true') - runner.setProperty(PutDatabaseRecord.QUOTE_TABLE_IDENTIFIER, 'true') - def settings = new PutDatabaseRecord.DMLSettings(runner.getProcessContext()) - - processor.with { - - assertEquals('"test_catalog"."test_schema"."test_table"', - generateTableName(settings,"test_catalog","test_schema","test_table",tableSchema)) - - } - } - - @Test - void testInsertMismatchedCompatibleDataTypes() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("dt", RecordFieldType.BIGINT) - - LocalDate testDate1 = LocalDate.of(2021, 1, 26) - Date jdbcDate1 = Date.valueOf(testDate1) // in local TZ - BigInteger nifiDate1 = jdbcDate1.getTime() // in local TZ - - LocalDate testDate2 = LocalDate.of(2021, 7, 26) - Date jdbcDate2 = Date.valueOf(testDate2) // in local TZ - BigInteger nifiDate2 = jdbcDate2.getTime() // in local TZ - - parser.addRecord(1, 'rec1', 101, nifiDate1) - parser.addRecord(2, 'rec2', 102, nifiDate2) - parser.addRecord(3, 'rec3', 103, null) - parser.addRecord(4, 'rec4', 104, null) - parser.addRecord(5, null, 105, null) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertEquals(101, rs.getInt(3)) - assertEquals(jdbcDate1.toString(), rs.getDate(4).toString()) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertEquals(102, rs.getInt(3)) - assertEquals(jdbcDate2.toString(), rs.getDate(4).toString()) - assertTrue(rs.next()) - assertEquals(3, rs.getInt(1)) - assertEquals('rec3', rs.getString(2)) - assertEquals(103, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertTrue(rs.next()) - assertEquals(4, rs.getInt(1)) - assertEquals('rec4', rs.getString(2)) - assertEquals(104, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertTrue(rs.next()) - assertEquals(5, rs.getInt(1)) - assertNull(rs.getString(2)) - assertEquals(105, rs.getInt(3)) - assertNull(rs.getDate(4)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - - @Test - void testInsertMismatchedNotCompatibleDataTypes() throws InitializationException, ProcessException, SQLException, IOException { - recreateTable(createPersons) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.STRING) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("dt", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.FLOAT.getDataType()).getFieldType()); - - LocalDate testDate1 = LocalDate.of(2021, 1, 26) - BigInteger nifiDate1 = testDate1.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli() // in UTC - Date jdbcDate1 = Date.valueOf(testDate1) // in local TZ - LocalDate testDate2 = LocalDate.of(2021, 7, 26) - BigInteger nifiDate2 = testDate2.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli() // in UTC - Date jdbcDate2 = Date.valueOf(testDate2) // in local TZ - - parser.addRecord('1', 'rec1', 101, [1.0,2.0]) - parser.addRecord('2', 'rec2', 102, [3.0,4.0]) - parser.addRecord('3', 'rec3', 103, null) - parser.addRecord('4', 'rec4', 104, null) - parser.addRecord('5', null, 105, null) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - // A SQLFeatureNotSupportedException exception is expected from Derby when you try to put the data as an ARRAY - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - } - - @Test - void testLongVarchar() throws InitializationException, ProcessException, SQLException, IOException { - // Manually create and drop the tables and schemas - def conn = dbcp.connection - def stmt = conn.createStatement() - try { - stmt.execute('DROP TABLE TEMP') - } catch(ex) { - // Do nothing, table may not exist - } - stmt.execute('CREATE TABLE TEMP (id integer primary key, name long varchar)') - - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - - parser.addRecord(1, 'rec1') - parser.addRecord(2, 'rec2') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'TEMP') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - ResultSet rs = stmt.executeQuery('SELECT * FROM TEMP') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals('rec1', rs.getString(2)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals('rec2', rs.getString(2)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertWithDifferentColumnOrdering() throws InitializationException, ProcessException, SQLException, IOException { - // Manually create and drop the tables and schemas - def conn = dbcp.connection - def stmt = conn.createStatement() - try { - stmt.execute('DROP TABLE TEMP') - } catch(ex) { - // Do nothing, table may not exist - } - stmt.execute('CREATE TABLE TEMP (id integer primary key, code integer, name long varchar)') - - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("code", RecordFieldType.INT) - - // change order of columns - parser.addRecord('rec1', 1, 101) - parser.addRecord('rec2', 2, 102) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'TEMP') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - ResultSet rs = stmt.executeQuery('SELECT * FROM TEMP') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - assertEquals(101, rs.getInt(2)) - assertEquals('rec1', rs.getString(3)) - assertTrue(rs.next()) - assertEquals(2, rs.getInt(1)) - assertEquals(102, rs.getInt(2)) - assertEquals('rec2', rs.getString(3)) - assertFalse(rs.next()) - - stmt.close() - conn.close() - } - - @Test - void testInsertWithBlobClob() throws Exception { - String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + - "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" - - recreateTable(createTableWithBlob) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - byte[] bytes = "BLOB".getBytes() - Byte[] blobRecordValue = new Byte[bytes.length] - (0 .. (bytes.length-1)).each { i -> blobRecordValue[i] = bytes[i].longValue() } - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("content", RecordFieldType.ARRAY) - - parser.addRecord(1, 'rec1', 101, blobRecordValue) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - Clob clob = rs.getClob(2) - assertNotNull(clob) - char[] clobText = new char[5] - int numBytes = clob.characterStream.read(clobText) - assertEquals(4, numBytes) - // Ignore last character, it's meant to ensure that only 4 bytes were read even though the buffer is 5 bytes - assertEquals('rec1', new String(clobText).substring(0,4)) - Blob blob = rs.getBlob(3) - assertEquals("BLOB", new String(blob.getBytes(1, blob.length() as int))) - assertEquals(101, rs.getInt(4)) - - stmt.close() - conn.close() - } - - @Test - void testInsertWithBlobClobObjectArraySource() throws Exception { - String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + - "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" - - recreateTable(createTableWithBlob) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - byte[] bytes = "BLOB".getBytes() - Object[] blobRecordValue = new Object[bytes.length] - (0 .. (bytes.length-1)).each { i -> blobRecordValue[i] = bytes[i] } - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("content", RecordFieldType.ARRAY) - - parser.addRecord(1, 'rec1', 101, blobRecordValue) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - Clob clob = rs.getClob(2) - assertNotNull(clob) - char[] clobText = new char[5] - int numBytes = clob.characterStream.read(clobText) - assertEquals(4, numBytes) - // Ignore last character, it's meant to ensure that only 4 bytes were read even though the buffer is 5 bytes - assertEquals('rec1', new String(clobText).substring(0,4)) - Blob blob = rs.getBlob(3) - assertEquals("BLOB", new String(blob.getBytes(1, blob.length() as int))) - assertEquals(101, rs.getInt(4)) - - stmt.close() - conn.close() - } - - @Test - void testInsertWithBlobStringSource() throws Exception { - String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + - "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" - - recreateTable(createTableWithBlob) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("content", RecordFieldType.STRING) - - parser.addRecord(1, 'rec1', 101, 'BLOB') - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) - final Connection conn = dbcp.getConnection() - final Statement stmt = conn.createStatement() - final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') - assertTrue(rs.next()) - assertEquals(1, rs.getInt(1)) - Clob clob = rs.getClob(2) - assertNotNull(clob) - char[] clobText = new char[5] - int numBytes = clob.characterStream.read(clobText) - assertEquals(4, numBytes) - // Ignore last character, it's meant to ensure that only 4 bytes were read even though the buffer is 5 bytes - assertEquals('rec1', new String(clobText).substring(0,4)) - Blob blob = rs.getBlob(3) - assertEquals("BLOB", new String(blob.getBytes(1, blob.length() as int))) - assertEquals(101, rs.getInt(4)) - - stmt.close() - conn.close() - } - - @Test - void testInsertWithBlobIntegerArraySource() throws Exception { - String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + - "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" - - recreateTable(createTableWithBlob) - final MockRecordParser parser = new MockRecordParser() - runner.addControllerService("parser", parser) - runner.enableControllerService(parser) - - parser.addSchemaField("id", RecordFieldType.INT) - parser.addSchemaField("name", RecordFieldType.STRING) - parser.addSchemaField("code", RecordFieldType.INT) - parser.addSchemaField("content", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType()).getFieldType()) - - parser.addRecord(1, 'rec1', 101, [1,2,3] as Integer[]) - - runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') - runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) - runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') - - runner.enqueue(new byte[0]) - runner.run() - - runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_RETRY, 0) - runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) - } -} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CountTextTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CountTextTest.java new file mode 100644 index 0000000000..08014cd968 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CountTextTest.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License") you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard; + +import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.util.MockComponentLog; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; + +public class CountTextTest { + private static final String TLC = "text.line.count"; + private static final String TLNEC = "text.line.nonempty.count"; + private static final String TWC = "text.word.count"; + private static final String TCC = "text.character.count"; + private TestRunner runner; + + @BeforeEach + void setupRunner() { + runner = TestRunners.newTestRunner(CountText.class); + } + + + @Test + void testShouldCountAllMetrics() throws IOException { + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true"); + + final Path inputPath = Paths.get("src/test/resources/TestCountText/jabberwocky.txt"); + + final Map expectedValues = new HashMap<>(); + expectedValues.put(TLC, "34"); + expectedValues.put(TLNEC, "28"); + expectedValues.put(TWC, "166"); + expectedValues.put(TCC, "900"); + + runner.enqueue(Files.readAllBytes(inputPath)); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + for (final Map.Entry entry: expectedValues.entrySet()) { + final String attribute = entry.getKey(); + final String expectedValue = entry.getValue(); + flowFile.assertAttributeEquals(attribute, expectedValue); + } + } + + @Test + void testShouldCountEachMetric() throws IOException { + final Path inputPath = Paths.get("src/test/resources/TestCountText/jabberwocky.txt"); + + final Map expectedValues = new HashMap<>(); + expectedValues.put(TLC, "34"); + expectedValues.put(TLNEC, "28"); + expectedValues.put(TWC, "166"); + expectedValues.put(TCC, "900"); + + final Map linesOnly = Collections.singletonMap(CountText.TEXT_LINE_COUNT_PD, "true"); + final Map linesNonEmptyOnly = Collections.singletonMap(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true"); + final Map wordsOnly = Collections.singletonMap(CountText.TEXT_WORD_COUNT_PD, "true"); + final Map charactersOnly = Collections.singletonMap(CountText.TEXT_CHARACTER_COUNT_PD, "true"); + + final List> scenarios = Arrays.asList(linesOnly, linesNonEmptyOnly, wordsOnly, charactersOnly); + + for (final Map map: scenarios) { + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false"); + + // Apply the scenario-specific properties + for (final Map.Entry entry: map.entrySet()) { + runner.setProperty(entry.getKey(), entry.getValue()); + } + + runner.clearProvenanceEvents(); + runner.clearTransferState(); + runner.enqueue(Files.readAllBytes(inputPath)); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + for (final Map.Entry entry: expectedValues.entrySet()) { + final String attribute = entry.getKey(); + final String expectedValue = entry.getValue(); + + if (flowFile.getAttributes().containsKey(attribute)) { + flowFile.assertAttributeEquals(attribute, expectedValue); + } + } + } + } + + @Test + void testShouldCountWordsSplitOnSymbol() throws IOException { + final Path inputPath = Paths.get("src/test/resources/TestCountText/jabberwocky.txt"); + + final String EXPECTED_WORD_COUNT = "167"; + + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false"); + runner.setProperty(CountText.SPLIT_WORDS_ON_SYMBOLS_PD, "true"); + + runner.clearProvenanceEvents(); + runner.clearTransferState(); + runner.enqueue(Files.readAllBytes(inputPath)); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + flowFile.assertAttributeEquals(CountText.TEXT_WORD_COUNT, EXPECTED_WORD_COUNT); + } + + @Test + void testShouldCountIndependentlyPerFlowFile() throws IOException { + final Path inputPath = Paths.get("src/test/resources/TestCountText/jabberwocky.txt"); + + final Map expectedValues = new HashMap<>(); + expectedValues.put(TLC, "34"); + expectedValues.put(TLNEC, "28"); + expectedValues.put(TWC, "166"); + expectedValues.put(TCC, "900"); + + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true"); + + for (int i = 0; i < 2; i++) { + runner.clearProvenanceEvents(); + runner.clearTransferState(); + runner.enqueue(Files.readAllBytes(inputPath)); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + for (final Map.Entry entry: expectedValues.entrySet()) { + final String attribute = entry.getKey(); + final String expectedValue = entry.getValue(); + + flowFile.assertAttributeEquals(attribute, expectedValue); + } + } + } + + @Test + void testShouldTrackSessionCountersAcrossMultipleFlowfiles() throws IOException, NoSuchFieldException, IllegalAccessException { + final Path inputPath = Paths.get("src/test/resources/TestCountText/jabberwocky.txt"); + + final Map expectedValues = new HashMap<>(); + expectedValues.put(TLC, "34"); + expectedValues.put(TLNEC, "28"); + expectedValues.put(TWC, "166"); + expectedValues.put(TCC, "900"); + + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true"); + + final int n = 2; + for (int i = 0; i < n; i++) { + runner.clearTransferState(); + runner.enqueue(Files.readAllBytes(inputPath)); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + for (final Map.Entry entry: expectedValues.entrySet()) { + final String attribute = entry.getKey(); + final String expectedValue = entry.getValue(); + + flowFile.assertAttributeEquals(attribute, expectedValue); + } + } + + assertEquals(Long.valueOf(expectedValues.get(TLC)) * n, runner.getCounterValue("Lines Counted")); + assertEquals(Long.valueOf(expectedValues.get(TLNEC)) * n, runner.getCounterValue("Lines (non-empty) Counted")); + assertEquals(Long.valueOf(expectedValues.get(TWC)) * n, runner.getCounterValue("Words Counted")); + assertEquals(Long.valueOf(expectedValues.get(TCC)) * n, runner.getCounterValue("Characters Counted")); + } + + @Test + void testShouldHandleInternalError() { + CountText ct = new CountText() { + @Override + int countWordsInLine(String line, boolean splitWordsOnSymbols) throws IOException { + throw new IOException("Expected exception"); + } + }; + + final TestRunner runner = TestRunners.newTestRunner(ct); + final String INPUT_TEXT = "This flowfile should throw an error"; + + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "true"); + runner.setProperty(CountText.CHARACTER_ENCODING_PD, StandardCharsets.US_ASCII.displayName()); + + runner.enqueue(INPUT_TEXT.getBytes()); + + // Need initialize = true to run #onScheduled() + runner.run(1, true, true); + + runner.assertAllFlowFilesTransferred(CountText.REL_FAILURE, 1); + } + + @Test + void testShouldIgnoreWhitespaceWordsWhenCounting() { + final String INPUT_TEXT = "a b c"; + + final String EXPECTED_WORD_COUNT = "3"; + + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false"); + runner.setProperty(CountText.SPLIT_WORDS_ON_SYMBOLS_PD, "true"); + + runner.clearProvenanceEvents(); + runner.clearTransferState(); + runner.enqueue(INPUT_TEXT.getBytes()); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + flowFile.assertAttributeEquals(CountText.TEXT_WORD_COUNT, EXPECTED_WORD_COUNT); + } + + @Test + void testShouldIgnoreWhitespaceWordsWhenCountingDebugMode() { + final MockComponentLog componentLogger = spy(new MockComponentLog("processorId", new CountText())); + doReturn(true).when(componentLogger).isDebugEnabled(); + final TestRunner runner = TestRunners.newTestRunner(CountText.class, componentLogger); + final String INPUT_TEXT = "a b c"; + + final String EXPECTED_WORD_COUNT = "3"; + + // Reset the processor properties + runner.setProperty(CountText.TEXT_LINE_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_LINE_NONEMPTY_COUNT_PD, "false"); + runner.setProperty(CountText.TEXT_WORD_COUNT_PD, "true"); + runner.setProperty(CountText.TEXT_CHARACTER_COUNT_PD, "false"); + runner.setProperty(CountText.SPLIT_WORDS_ON_SYMBOLS_PD, "true"); + + runner.clearProvenanceEvents(); + runner.clearTransferState(); + runner.enqueue(INPUT_TEXT.getBytes()); + + runner.run(); + + runner.assertAllFlowFilesTransferred(CountText.REL_SUCCESS, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(CountText.REL_SUCCESS).get(0); + + flowFile.assertAttributeEquals(CountText.TEXT_WORD_COUNT, EXPECTED_WORD_COUNT); + } + +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CryptographicHashAttributeTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CryptographicHashAttributeTest.java new file mode 100644 index 0000000000..e82ca4b281 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CryptographicHashAttributeTest.java @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License") you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard; + +import org.apache.nifi.security.util.crypto.HashAlgorithm; +import org.apache.nifi.security.util.crypto.HashService; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.Security; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class CryptographicHashAttributeTest { + private TestRunner runner; + + @BeforeAll + static void setUpOnce() { + Security.addProvider(new BouncyCastleProvider()); + } + + @BeforeEach + void setupRunner() { + runner = TestRunners.newTestRunner(new CryptographicHashAttribute()); + } + + @Test + void testShouldCalculateHashOfPresentAttribute() { + // Create attributes for username and date + final Map attributes = new HashMap<>(); + attributes.put("username", "alopresto"); + attributes.put("date", ZonedDateTime.now().format(DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss.SSS Z"))); + + final Set attributeKeys = attributes.keySet(); + + for (final HashAlgorithm algorithm : HashAlgorithm.values()) { + final String expectedUsernameHash = HashService.hashValue(algorithm, attributes.get("username")); + final String expectedDateHash = HashService.hashValue(algorithm, attributes.get("date")); + + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.getName()); + + // Add the desired dynamic properties + for (final String attr: attributeKeys) { + runner.setProperty(attr, String.format("%s_%s", attr, algorithm.getName())); + } + + // Insert the attributes in the mock flowfile + runner.enqueue(new byte[0], attributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + + flowFile.assertAttributeEquals(String.format("username_%s", algorithm.getName()), expectedUsernameHash); + flowFile.assertAttributeEquals(String.format("date_%s", algorithm.getName()), expectedDateHash); + } + } + + @Test + void testShouldCalculateHashOfMissingAttribute() { + // Create attributes for username (empty string) and date (null) + final Map attributes = new HashMap<>(); + attributes.put("username", ""); + attributes.put("date", null); + + final Set attributeKeys = attributes.keySet(); + + for (final HashAlgorithm algorithm: HashAlgorithm.values()) { + final String expectedUsernameHash = HashService.hashValue(algorithm, attributes.get("username")); + final String expectedDateHash = null; + + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.getName()); + + // Add the desired dynamic properties + for (final String attr: attributeKeys) { + runner.setProperty(attr, String.format("%s_%s", attr, algorithm.getName())); + } + + // Insert the attributes in the mock flowfile + runner.enqueue(new byte[0], attributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + + flowFile.assertAttributeEquals(String.format("username_%s", algorithm.getName()), expectedUsernameHash); + flowFile.assertAttributeEquals(String.format("date_%s", algorithm.getName()), expectedDateHash); + } + } + + @Test + void testShouldRouteToFailureOnProhibitedMissingAttribute() { + // Create attributes for username (empty string) and date (null) + final Map attributes = new HashMap<>(); + attributes.put("username", ""); + attributes.put("date", null); + + final Set attributeKeys = attributes.keySet(); + + for (final HashAlgorithm algorithm: HashAlgorithm.values()) { + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.getName()); + + // Set to fail if there are missing attributes + runner.setProperty(CryptographicHashAttribute.PARTIAL_ATTR_ROUTE_POLICY, CryptographicHashAttribute.PartialAttributePolicy.PROHIBIT.name()); + + // Add the desired dynamic properties + for (final String attr: attributeKeys) { + runner.setProperty(attr, String.format("%s_%s", attr, algorithm.getName())); + } + + // Insert the attributes in the mock flowfile + runner.enqueue(new byte[0], attributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 1); + runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 0); + + final List failedFlowFiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_FAILURE); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = failedFlowFiles.get(0); + for (final String missingAttribute: attributeKeys) { + flowFile.assertAttributeNotExists(String.format("%s_%s", missingAttribute, algorithm.getName())); + } + } + } + + @Test + void testShouldRouteToFailureOnEmptyAttributes() { + // Create attributes for username (empty string) and date (null) + final Map attributes = new HashMap<>(); + attributes.put("username", ""); + attributes.put("date", null); + + final Set attributeKeys = attributes.keySet(); + + for (final HashAlgorithm algorithm: HashAlgorithm.values()) { + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.getName()); + + // Set to fail if all attributes are missing + runner.setProperty(CryptographicHashAttribute.FAIL_WHEN_EMPTY, "true"); + + // Insert the attributes in the mock flowfile + runner.enqueue(new byte[0], attributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 1); + runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 0); + + final List failedFlowFiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_FAILURE); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = failedFlowFiles.get(0); + for (final String missingAttribute: attributeKeys) { + flowFile.assertAttributeNotExists(String.format("%s_%s", missingAttribute, algorithm.getName())); + } + } + } + + @Test + void testShouldRouteToSuccessOnAllowPartial() { + // Create attributes for username (empty string) and date (null) + final Map attributes = new HashMap<>(); + attributes.put("username", ""); + + final Set attributeKeys = attributes.keySet(); + + for (final HashAlgorithm algorithm: HashAlgorithm.values()) { + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.getName()); + + // Set to fail if there are missing attributes + runner.setProperty(CryptographicHashAttribute.PARTIAL_ATTR_ROUTE_POLICY, CryptographicHashAttribute.PartialAttributePolicy.ALLOW.name()); + + // Add the desired dynamic properties + for (final String attr: attributeKeys) { + runner.setProperty(attr, String.format("%s_%s", attr, algorithm.getName())); + } + + // Insert the attributes in the mock flowfile + runner.enqueue(new byte[0], attributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1); + + final List successfulFlowFiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowFiles.get(0); + for (final String attribute: attributeKeys) { + flowFile.assertAttributeExists(String.format("%s_%s", attribute, algorithm.getName())); + } + } + } + + @Test + void testShouldCalculateHashWithVariousCharacterEncodings() { + // Create attributes + final Map attributes = new HashMap<>(); + attributes.put("test_attribute", "apachenifi"); + final Set attributeKeys = attributes.keySet(); + + final HashAlgorithm algorithm = HashAlgorithm.MD5; + + final List charsets = Arrays.asList(StandardCharsets.UTF_8, StandardCharsets.UTF_16, StandardCharsets.UTF_16LE, StandardCharsets.UTF_16BE); + + final Map EXPECTED_MD5_HASHES = new HashMap<>(); + EXPECTED_MD5_HASHES.put(StandardCharsets.UTF_8.name(), "a968b5ec1d52449963dcc517789baaaf"); + EXPECTED_MD5_HASHES.put(StandardCharsets.UTF_16.name(), "b8413d18f7e64042bb0322a1cd61eba2"); + EXPECTED_MD5_HASHES.put(StandardCharsets.UTF_16BE.name(), "b8413d18f7e64042bb0322a1cd61eba2"); + EXPECTED_MD5_HASHES.put(StandardCharsets.UTF_16LE.name(), "91c3b67f9f8ae77156f21f271cc09121"); + + for (final Charset charset: charsets) { + // Calculate the expected hash value given the character set + final String EXPECTED_HASH = HashService.hashValue(algorithm, attributes.get("test_attribute"), charset); + + // Sanity check + assertEquals(EXPECTED_HASH, EXPECTED_MD5_HASHES.get(charset.name())); + + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the properties + runner.setProperty(CryptographicHashAttribute.HASH_ALGORITHM, algorithm.getName()); + runner.setProperty(CryptographicHashAttribute.CHARACTER_SET, charset.name()); + + // Add the desired dynamic properties + for (final String attr: attributeKeys) { + runner.setProperty(attr, String.format("%s_%s", attr, algorithm.getName())); + } + + // Insert the attributes in the mock flowfile + runner.enqueue(new byte[0], attributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashAttribute.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashAttribute.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashAttribute.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + + flowFile.assertAttributeEquals(String.format("test_attribute_%s", algorithm.getName()), EXPECTED_HASH); + } + } +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CryptographicHashContentTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CryptographicHashContentTest.java new file mode 100644 index 0000000000..5573c79372 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/CryptographicHashContentTest.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License") you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard; + +import org.apache.commons.lang3.StringUtils; +import org.apache.nifi.security.util.crypto.HashAlgorithm; +import org.apache.nifi.security.util.crypto.HashService; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.Security; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class CryptographicHashContentTest { + private TestRunner runner; + + @BeforeAll + static void setUpOnce() { + Security.addProvider(new BouncyCastleProvider()); + } + + @BeforeEach + void setupRunner() { + runner = TestRunners.newTestRunner(new CryptographicHashContent()); + } + + @Test + void testShouldCalculateHashOfPresentContent() throws IOException { + // Generate some long content (90 KB) + final String longContent = StringUtils.repeat("apachenifi ", 8192); + + for (final HashAlgorithm algorithm : HashAlgorithm.values()) { + final String expectedContentHash = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(longContent.getBytes())); + + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.getName()); + + // Insert the content in the mock flowfile + runner.enqueue(longContent.getBytes(StandardCharsets.UTF_8), + Collections.singletonMap("size", String.valueOf(longContent.length()))); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + String hashAttribute = String.format("content_%s", algorithm.getName()); + flowFile.assertAttributeExists(hashAttribute); + flowFile.assertAttributeEquals(hashAttribute, expectedContentHash); + } + } + + @Test + void testShouldCalculateHashOfEmptyContent() throws IOException { + final String emptyContent = ""; + + for (final HashAlgorithm algorithm : HashAlgorithm.values()) { + final String expectedContentHash = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(emptyContent.getBytes())); + + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.getName()); + + // Insert the content in the mock flowfile + runner.enqueue(emptyContent.getBytes(StandardCharsets.UTF_8), Collections.singletonMap("size", "0")); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + String hashAttribute = String.format("content_%s", algorithm.getName()); + flowFile.assertAttributeExists(hashAttribute); + + String hashedContent = flowFile.getAttribute(hashAttribute); + + assertEquals(expectedContentHash, hashedContent); + } + } + + /** + * This test works because {@link MockFlowFile} uses the actual internal {@code data.size} for {@code getSize ( )}, + * while {@code StandardFlowFileRecord} uses a separate {@code size} field. May need to use {@code flowfile.getContentClaim ( ) .getLength ( )}. + */ + @Test + void testShouldCalculateHashOfContentWithIncorrectSizeAttribute() throws IOException { + final String nonEmptyContent = "apachenifi"; + + final TestRunner runner = TestRunners.newTestRunner(new CryptographicHashContent()); + + for (final HashAlgorithm algorithm : HashAlgorithm.values()) { + final String expectedContentHash = HashService.hashValueStreaming(algorithm, new ByteArrayInputStream(nonEmptyContent.getBytes())); + + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the algorithm + runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.getName()); + + // Insert the content in the mock flowfile (with the wrong size attribute) + runner.enqueue(nonEmptyContent.getBytes(StandardCharsets.UTF_8), Collections.singletonMap("size", "0")); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + String hashAttribute = String.format("content_%s", algorithm.getName()); + flowFile.assertAttributeExists(hashAttribute); + flowFile.assertAttributeEquals(hashAttribute, expectedContentHash); + } + } + + @Test + void testShouldOverwriteExistingAttribute() { + final String nonEmptyContent = "apachenifi"; + final String oldHashAttributeValue = "OLD VALUE"; + + HashAlgorithm algorithm = HashAlgorithm.SHA256; + + final String expectedContentHash = HashService.hashValue(algorithm, nonEmptyContent); + + // Set the algorithm + runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.getName()); + + // Insert the content in the mock flowfile (with an existing attribute) + final Map oldAttributes = Collections.singletonMap(String.format("content_%s", algorithm.getName()), + oldHashAttributeValue); + runner.enqueue(nonEmptyContent.getBytes(StandardCharsets.UTF_8), + oldAttributes); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 0); + runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 1); + + final List successfulFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_SUCCESS); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = successfulFlowfiles.get(0); + String hashAttribute = String.format("content_%s", algorithm.getName()); + flowFile.assertAttributeExists(hashAttribute); + + String hashedContent = flowFile.getAttribute(hashAttribute); + + assertNotEquals(oldHashAttributeValue, hashedContent); + assertEquals(expectedContentHash, hashedContent); + } + + @Test + void testShouldRouteToFailureOnEmptyContent() { + final String emptyContent = ""; + + for (final HashAlgorithm algorithm : HashAlgorithm.values()) { + // Reset the processor + runner.clearProperties(); + runner.clearProvenanceEvents(); + runner.clearTransferState(); + + // Set the failure property + runner.setProperty(CryptographicHashContent.FAIL_WHEN_EMPTY, "true"); + + // Set the algorithm + runner.setProperty(CryptographicHashContent.HASH_ALGORITHM, algorithm.getName()); + + // Insert the content in the mock flowfile + runner.enqueue(emptyContent.getBytes(StandardCharsets.UTF_8)); + + runner.run(1); + + runner.assertTransferCount(CryptographicHashContent.REL_FAILURE, 1); + runner.assertTransferCount(CryptographicHashContent.REL_SUCCESS, 0); + + final List failedFlowfiles = runner.getFlowFilesForRelationship(CryptographicHashContent.REL_FAILURE); + + // Extract the generated attributes from the flowfile + MockFlowFile flowFile = failedFlowfiles.get(0); + String hashAttribute = String.format("content_%s", algorithm.getName()); + flowFile.assertAttributeNotExists(hashAttribute); + } + } +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/PutDatabaseRecordTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/PutDatabaseRecordTest.java index d0d3e1d186..dfef1eacdb 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/PutDatabaseRecordTest.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/PutDatabaseRecordTest.java @@ -16,11 +16,19 @@ */ package org.apache.nifi.processors.standard; +import org.apache.commons.dbcp2.DelegatingConnection; import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.processor.util.pattern.RollbackOnFailure; import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.serialization.MalformedRecordException; +import org.apache.nifi.serialization.SimpleRecordSchema; +import org.apache.nifi.serialization.record.MapRecord; +import org.apache.nifi.serialization.record.MockRecordFailureType; import org.apache.nifi.serialization.record.MockRecordParser; +import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.apache.nifi.util.file.FileUtils; @@ -28,23 +36,45 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.stubbing.Answer; import java.io.File; import java.io.IOException; +import java.math.BigInteger; +import java.sql.Blob; +import java.sql.Clob; import java.sql.Connection; import java.sql.Date; import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLDataException; import java.sql.SQLException; import java.sql.SQLNonTransientConnectionException; import java.sql.Statement; import java.time.LocalDate; import java.time.ZoneOffset; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class PutDatabaseRecordTest { @@ -57,6 +87,11 @@ public class PutDatabaseRecordTest { private static final String createPersons = "CREATE TABLE PERSONS (id integer primary key, name varchar(100)," + " code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000), dt date)"; + private static final String createPersonsSchema1 = "CREATE TABLE SCHEMA1.PERSONS (id integer primary key, name varchar(100)," + + " code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000), dt date)"; + private static final String createPersonsSchema2 = "CREATE TABLE SCHEMA2.PERSONS (id2 integer primary key, name varchar(100)," + + " code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000), dt date)"; + private final static String DB_LOCATION = "target/db_pdr"; TestRunner runner; @@ -172,6 +207,1558 @@ public class PutDatabaseRecordTest { runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); } + @Test + void testGeneratePreparedStatements() throws SQLException, MalformedRecordException { + + final List fields = Arrays.asList(new RecordField("id", RecordFieldType.INT.getDataType()), + new RecordField("name", RecordFieldType.STRING.getDataType()), + new RecordField("code", RecordFieldType.INT.getDataType()), + new RecordField("non_existing", RecordFieldType.BOOLEAN.getDataType())); + final RecordSchema schema = new SimpleRecordSchema(fields); + + final PutDatabaseRecord.TableSchema tableSchema = new PutDatabaseRecord.TableSchema( + Arrays.asList( + new PutDatabaseRecord.ColumnDescription("id", 4, true, 2, false), + new PutDatabaseRecord.ColumnDescription("name", 12, true, 255, true), + new PutDatabaseRecord.ColumnDescription("code", 4, true, 10, true) + ), + false, + new HashSet(Arrays.asList("id")), + "" + ); + + runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, "false"); + runner.setProperty(PutDatabaseRecord.UNMATCHED_FIELD_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_FIELD); + runner.setProperty(PutDatabaseRecord.UNMATCHED_COLUMN_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_COLUMN); + runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, "false"); + runner.setProperty(PutDatabaseRecord.QUOTE_TABLE_IDENTIFIER, "false"); + final PutDatabaseRecord.DMLSettings settings = new PutDatabaseRecord.DMLSettings(runner.getProcessContext()); + + assertEquals("INSERT INTO PERSONS (id, name, code) VALUES (?,?,?)", + processor.generateInsert(schema, "PERSONS", tableSchema, settings).getSql()); + assertEquals("UPDATE PERSONS SET name = ?, code = ? WHERE id = ?", + processor.generateUpdate(schema, "PERSONS", null, tableSchema, settings).getSql()); + assertEquals("DELETE FROM PERSONS WHERE (id = ?) AND (name = ? OR (name is null AND ? is null)) AND (code = ? OR (code is null AND ? is null))", + processor.generateDelete(schema, "PERSONS", tableSchema, settings).getSql()); + } + + @Test + void testGeneratePreparedStatementsFailUnmatchedField() throws SQLException, MalformedRecordException { + + final List fields = Arrays.asList(new RecordField("id", RecordFieldType.INT.getDataType()), + new RecordField("name", RecordFieldType.STRING.getDataType()), + new RecordField("code", RecordFieldType.INT.getDataType()), + new RecordField("non_existing", RecordFieldType.BOOLEAN.getDataType())); + final RecordSchema schema = new SimpleRecordSchema(fields); + + final PutDatabaseRecord.TableSchema tableSchema = new PutDatabaseRecord.TableSchema( + Arrays.asList( + new PutDatabaseRecord.ColumnDescription("id", 4, true, 2, false), + new PutDatabaseRecord.ColumnDescription("name", 12, true, 255, true), + new PutDatabaseRecord.ColumnDescription("code", 4, true, 10, true) + ), + false, + new HashSet(Arrays.asList("id")), + "" + ); + + runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, "false"); + runner.setProperty(PutDatabaseRecord.UNMATCHED_FIELD_BEHAVIOR, PutDatabaseRecord.FAIL_UNMATCHED_FIELD); + runner.setProperty(PutDatabaseRecord.UNMATCHED_COLUMN_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_COLUMN); + runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, "false"); + runner.setProperty(PutDatabaseRecord.QUOTE_TABLE_IDENTIFIER, "false"); + final PutDatabaseRecord.DMLSettings settings = new PutDatabaseRecord.DMLSettings(runner.getProcessContext()); + + SQLDataException e = assertThrows(SQLDataException.class, + () -> processor.generateInsert(schema, "PERSONS", tableSchema, settings), + "generateInsert should fail with unmatched fields"); + assertEquals("Cannot map field 'non_existing' to any column in the database\nColumns: id,name,code", e.getMessage()); + + e = assertThrows(SQLDataException.class, + () -> processor.generateUpdate(schema, "PERSONS", null, tableSchema, settings), + "generateUpdate should fail with unmatched fields"); + assertEquals("Cannot map field 'non_existing' to any column in the database\nColumns: id,name,code", e.getMessage()); + + e = assertThrows(SQLDataException.class, + () -> processor.generateDelete(schema, "PERSONS", tableSchema, settings), + "generateDelete should fail with unmatched fields"); + assertEquals("Cannot map field 'non_existing' to any column in the database\nColumns: id,name,code", e.getMessage()); + } + + @Test + void testInsert() throws InitializationException, ProcessException, SQLException, IOException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("dt", RecordFieldType.DATE); + + LocalDate testDate1 = LocalDate.of(2021, 1, 26); + Date jdbcDate1 = Date.valueOf(testDate1); // in local TZ + LocalDate testDate2 = LocalDate.of(2021, 7, 26); + Date jdbcDate2 = Date.valueOf(testDate2); // in local TZ + + parser.addRecord(1, "rec1", 101, jdbcDate1); + parser.addRecord(2, "rec2", 102, jdbcDate2); + parser.addRecord(3, "rec3", 103, null); + parser.addRecord(4, "rec4", 104, null); + parser.addRecord(5, null, 105, null); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertEquals(jdbcDate1.toString(), rs.getDate(4).toString()); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(102, rs.getInt(3)); + assertEquals(jdbcDate2.toString(), rs.getDate(4).toString()); + assertTrue(rs.next()); + assertEquals(3, rs.getInt(1)); + assertEquals("rec3", rs.getString(2)); + assertEquals(103, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertTrue(rs.next()); + assertEquals(4, rs.getInt(1)); + assertEquals("rec4", rs.getString(2)); + assertEquals(104, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertTrue(rs.next()); + assertEquals(5, rs.getInt(1)); + assertNull(rs.getString(2)); + assertEquals(105, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertNonRequiredColumns() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("dt", RecordFieldType.DATE); + + LocalDate testDate1 = LocalDate.of(2021, 1, 26); + Date jdbcDate1 = Date.valueOf(testDate1); // in local TZ + LocalDate testDate2 = LocalDate.of(2021, 7, 26); + Date jdbcDate2 = Date.valueOf(testDate2); // in local TZ + + parser.addRecord(1, "rec1", jdbcDate1); + parser.addRecord(2, "rec2", jdbcDate2); + parser.addRecord(3, "rec3", null); + parser.addRecord(4, "rec4", null); + parser.addRecord(5, null, null); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + // Zero value because of the constraint + assertEquals(0, rs.getInt(3)); + assertEquals(jdbcDate1.toString(), rs.getDate(4).toString()); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(0, rs.getInt(3)); + assertEquals(jdbcDate2.toString(), rs.getDate(4).toString()); + assertTrue(rs.next()); + assertEquals(3, rs.getInt(1)); + assertEquals("rec3", rs.getString(2)); + assertEquals(0, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertTrue(rs.next()); + assertEquals(4, rs.getInt(1)); + assertEquals("rec4", rs.getString(2)); + assertEquals(0, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertTrue(rs.next()); + assertEquals(5, rs.getInt(1)); + assertNull(rs.getString(2)); + assertEquals(0, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertBatchUpdateException() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + parser.addRecord(2, "rec2", 102); + parser.addRecord(3, "rec3", 1000); // This record violates the constraint on the "code" column so should result in FlowFile being routed to failure + parser.addRecord(4, "rec4", 104); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_FAILURE, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + // Transaction should be rolled back and table should remain empty. + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertBatchUpdateExceptionRollbackOnFailure() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + parser.addRecord(2, "rec2", 102); + parser.addRecord(3, "rec3", 1000); + parser.addRecord(4, "rec4", 104); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(RollbackOnFailure.ROLLBACK_ON_FAILURE, "true"); + + runner.enqueue(new byte[0]); + runner.run(); + + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + // Transaction should be rolled back and table should remain empty. + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertNoTableSpecified() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "${not.a.real.attr}"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + } + + @Test + void testInsertNoTableExists() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS2"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(PutDatabaseRecord.REL_FAILURE).get(0); + final String errorMessage = flowFile.getAttribute("putdatabaserecord.error"); + assertTrue(errorMessage.contains("PERSONS2")); + } + + @Test + void testInsertViaSqlStatementType() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("sql", RecordFieldType.STRING); + + parser.addRecord("INSERT INTO PERSONS (id, name, code) VALUES (1, 'rec1',101)"); + parser.addRecord("INSERT INTO PERSONS (id, name, code) VALUES (2, 'rec2',102)"); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, "sql"); + + final Map attrs = new HashMap<>(); + attrs.put(PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE, "sql"); + runner.enqueue(new byte[0], attrs); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(102, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testMultipleInsertsViaSqlStatementType() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("sql", RecordFieldType.STRING); + + parser.addRecord("INSERT INTO PERSONS (id, name, code) VALUES (1, 'rec1',101);INSERT INTO PERSONS (id, name, code) VALUES (2, 'rec2',102)"); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, "sql"); + runner.setProperty(PutDatabaseRecord.ALLOW_MULTIPLE_STATEMENTS, "true"); + + final Map attrs = new HashMap<>(); + attrs.put(PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE, "sql"); + runner.enqueue(new byte[0], attrs); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(102, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testMultipleInsertsViaSqlStatementTypeBadSQL() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("sql", RecordFieldType.STRING); + + parser.addRecord("INSERT INTO PERSONS (id, name, code) VALUES (1, 'rec1',101);" + + "INSERT INTO PERSONS (id, name, code) VALUES (2, 'rec2',102);" + + "INSERT INTO PERSONS2 (id, name, code) VALUES (2, 'rec2',102);"); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, "sql"); + runner.setProperty(PutDatabaseRecord.ALLOW_MULTIPLE_STATEMENTS, "true"); + + final Map attrs = new HashMap<>(); + attrs.put(PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE, "sql"); + runner.enqueue(new byte[0], attrs); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + // The first two legitimate statements should have been rolled back + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInvalidData() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + parser.addRecord(2, "rec2", 102); + parser.addRecord(3, "rec3", 104); + + parser.failAfter(1); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_FAILURE, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + // Transaction should be rolled back and table should remain empty. + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testIOExceptionOnReadData() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + parser.addRecord(2, "rec2", 102); + parser.addRecord(3, "rec3", 104); + + parser.failAfter(1, MockRecordFailureType.IO_EXCEPTION); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_FAILURE, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + // Transaction should be rolled back and table should remain empty. + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testSqlStatementTypeNoValue() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("sql", RecordFieldType.STRING); + + parser.addRecord(""); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, "sql"); + + final Map attrs = new HashMap<>(); + attrs.put(PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE, "sql"); + runner.enqueue(new byte[0], attrs); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + } + + @Test + void testSqlStatementTypeNoValueRollbackOnFailure() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("sql", RecordFieldType.STRING); + + parser.addRecord(""); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_ATTR_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.FIELD_CONTAINING_SQL, "sql"); + runner.setProperty(RollbackOnFailure.ROLLBACK_ON_FAILURE, "true"); + + final Map attrs = new HashMap<>(); + attrs.put(PutDatabaseRecord.STATEMENT_TYPE_ATTRIBUTE, "sql"); + runner.enqueue(new byte[0], attrs); + + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 0); + } + + @Test + void testUpdate() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 201); + parser.addRecord(2, "rec2", 202); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + // Set some existing records with different values for name and code + final Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES (1,'x1',101, null)"); + stmt.execute("INSERT INTO PERSONS VALUES (2,'x2',102, null)"); + stmt.close(); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(202, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testUpdatePkNotFirst() throws InitializationException, ProcessException, SQLException { + recreateTable("CREATE TABLE PERSONS (name varchar(100), id integer primary key, code integer)"); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord("rec1", 1, 201); + parser.addRecord("rec2", 2, 202); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + // Set some existing records with different values for name and code + final Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES ('x1', 1, 101)"); + stmt.execute("INSERT INTO PERSONS VALUES ('x2', 2, 102)"); + stmt.close(); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals("rec1", rs.getString(1)); + assertEquals(1, rs.getInt(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals("rec2", rs.getString(1)); + assertEquals(2, rs.getInt(2)); + assertEquals(202, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testUpdateMultipleSchemas() throws InitializationException, ProcessException, SQLException { + // Manually create and drop the tables and schemas + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + stmt.execute("create schema SCHEMA1"); + stmt.execute("create schema SCHEMA2"); + stmt.execute(createPersonsSchema1); + stmt.execute(createPersonsSchema2); + + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 201); + parser.addRecord(2, "rec2", 202); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.SCHEMA_NAME, "SCHEMA1"); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + // Set some existing records with different values for name and code + Exception e; + ResultSet rs; + + stmt.execute("INSERT INTO SCHEMA1.PERSONS VALUES (1,'x1',101,null)"); + stmt.execute("INSERT INTO SCHEMA2.PERSONS VALUES (2,'x2',102,null)"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + rs = stmt.executeQuery("SELECT * FROM SCHEMA1.PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertFalse(rs.next()); + rs = stmt.executeQuery("SELECT * FROM SCHEMA2.PERSONS"); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + // Values should not have been updated + assertEquals("x2", rs.getString(2)); + assertEquals(102, rs.getInt(3)); + assertFalse(rs.next()); + + // Drop the schemas here so as not to interfere with other tests + stmt.execute("drop table SCHEMA1.PERSONS"); + stmt.execute("drop table SCHEMA2.PERSONS"); + stmt.execute("drop schema SCHEMA1 RESTRICT"); + stmt.execute("drop schema SCHEMA2 RESTRICT"); + stmt.close(); + + // Don't proceed if there was a problem with the asserts + rs = conn.getMetaData().getSchemas(); + List schemas = new ArrayList<>(); + while(rs.next()) { + schemas.add(rs.getString(1)); + } + assertFalse(schemas.contains("SCHEMA1")); + assertFalse(schemas.contains("SCHEMA2")); + conn.close(); + } + + @Test + void testUpdateAfterInsert() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 101); + parser.addRecord(2, "rec2", 102); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(102, rs.getInt(3)); + assertFalse(rs.next()); + stmt.close(); + runner.clearTransferState(); + + parser.addRecord(1, "rec1", 201); + parser.addRecord(2, "rec2", 202); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.enqueue(new byte[0]); + runner.run(1, true, false); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(202, rs.getInt(3)); + assertFalse(rs.next()); + stmt.close(); + conn.close(); + } + + @Test + void testUpdateNoPrimaryKeys() throws InitializationException, ProcessException, SQLException { + recreateTable("CREATE TABLE PERSONS (id integer, name varchar(100), code integer)"); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + parser.addRecord(1, "rec1", 201); + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + MockFlowFile flowFile = runner.getFlowFilesForRelationship(PutDatabaseRecord.REL_FAILURE).get(0); + assertEquals("Table 'PERSONS' not found or does not have a Primary Key and no Update Keys were specified", flowFile.getAttribute(PutDatabaseRecord.PUT_DATABASE_RECORD_ERROR)); + } + + @Test + void testUpdateSpecifyUpdateKeys() throws InitializationException, ProcessException, SQLException { + recreateTable("CREATE TABLE PERSONS (id integer, name varchar(100), code integer)"); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 201); + parser.addRecord(2, "rec2", 202); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, "id"); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + // Set some existing records with different values for name and code + final Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES (1,'x1',101)"); + stmt.execute("INSERT INTO PERSONS VALUES (2,'x2',102)"); + stmt.close(); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(202, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testUpdateSpecifyUpdateKeysNotFirst() throws InitializationException, ProcessException, SQLException { + recreateTable("CREATE TABLE PERSONS (id integer, name varchar(100), code integer)"); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 201); + parser.addRecord(2, "rec2", 202); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, "code"); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + // Set some existing records with different values for name and code + final Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES (10,'x1',201)"); + stmt.execute("INSERT INTO PERSONS VALUES (12,'x2',202)"); + stmt.close(); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(202, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testUpdateSpecifyQuotedUpdateKeys() throws InitializationException, ProcessException, SQLException { + recreateTable("CREATE TABLE PERSONS (\"id\" integer, name varchar(100), code integer)"); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(1, "rec1", 201); + parser.addRecord(2, "rec2", 202); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.UPDATE_TYPE); + runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, "${updateKey}"); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, "true"); + + // Set some existing records with different values for name and code + final Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES (1,'x1',101)"); + stmt.execute("INSERT INTO PERSONS VALUES (2,'x2',102)"); + stmt.close(); + + runner.enqueue(new byte[0], Collections.singletonMap("updateKey", "id")); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(202, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testDelete() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES (1, 'rec1', 101, null)"); + stmt.execute("INSERT INTO PERSONS VALUES (2, 'rec2', 102, null)"); + stmt.execute("INSERT INTO PERSONS VALUES (3, 'rec3', 103, null)"); + stmt.close(); + + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(2, "rec2", 102); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.DELETE_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(3, rs.getInt(1)); + assertEquals("rec3", rs.getString(2)); + assertEquals(103, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testDeleteWithNulls() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + Connection conn = dbcp.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("INSERT INTO PERSONS VALUES (1, 'rec1', 101, null)"); + stmt.execute("INSERT INTO PERSONS VALUES (2, 'rec2', null, null)"); + stmt.execute("INSERT INTO PERSONS VALUES (3, 'rec3', 103, null)"); + stmt.close(); + + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + parser.addRecord(2, "rec2", null); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.DELETE_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(3, rs.getInt(1)); + assertEquals("rec3", rs.getString(2)); + assertEquals(103, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testRecordPathOptions() throws InitializationException, SQLException { + recreateTable("CREATE TABLE PERSONS (id integer, name varchar(100), code integer)"); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + final List dataFields = new ArrayList<>(); + dataFields.add(new RecordField("id", RecordFieldType.INT.getDataType())); + dataFields.add(new RecordField("name", RecordFieldType.STRING.getDataType())); + dataFields.add(new RecordField("code", RecordFieldType.INT.getDataType())); + + final RecordSchema dataSchema = new SimpleRecordSchema(dataFields); + parser.addSchemaField("operation", RecordFieldType.STRING); + parser.addSchemaField(new RecordField("data", RecordFieldType.RECORD.getRecordDataType(dataSchema))); + + // CREATE, CREATE, CREATE, DELETE, UPDATE + parser.addRecord("INSERT", new MapRecord(dataSchema, createValues(1, "John Doe", 55))); + parser.addRecord("INSERT", new MapRecord(dataSchema, createValues(2, "Jane Doe", 44))); + parser.addRecord("INSERT", new MapRecord(dataSchema, createValues(3, "Jim Doe", 2))); + parser.addRecord("DELETE", new MapRecord(dataSchema, createValues(2, "Jane Doe", 44))); + parser.addRecord("UPDATE", new MapRecord(dataSchema, createValues(1, "John Doe", 201))); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.USE_RECORD_PATH); + runner.setProperty(PutDatabaseRecord.DATA_RECORD_PATH, "/data"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE_RECORD_PATH, "/operation"); + runner.setProperty(PutDatabaseRecord.UPDATE_KEYS, "id"); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertAllFlowFilesTransferred(PutDatabaseRecord.REL_SUCCESS, 1); + + Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("John Doe", rs.getString(2)); + assertEquals(201, rs.getInt(3)); + assertTrue(rs.next()); + assertEquals(3, rs.getInt(1)); + assertEquals("Jim Doe", rs.getString(2)); + assertEquals(2, rs.getInt(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertWithMaxBatchSize() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + for (int i = 1; i < 12; i++) { + parser.addRecord(i, String.format("rec%s", i), 100 + i); + } + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + runner.setProperty(PutDatabaseRecord.MAX_BATCH_SIZE, "5"); + + Supplier spyStmt = createPreparedStatementSpy(); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + + assertEquals(11, getTableSize()); + + assertNotNull(spyStmt.get()); + verify(spyStmt.get(), times(3)).executeBatch(); + } + + @Test + void testInsertWithDefaultMaxBatchSize() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + + for (int i = 1; i < 12; i++) { + parser.addRecord(i, String.format("rec%s", i), 100 + i); + } + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + Supplier spyStmt = createPreparedStatementSpy(); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + + assertEquals(11, getTableSize()); + + assertNotNull(spyStmt.get()); + verify(spyStmt.get(), times(1)).executeBatch(); + } + + @Test + void testGenerateTableName() throws Exception { + final List fields = Arrays.asList(new RecordField("id", RecordFieldType.INT.getDataType()), + new RecordField("name", RecordFieldType.STRING.getDataType()), + new RecordField("code", RecordFieldType.INT.getDataType()), + new RecordField("non_existing", RecordFieldType.BOOLEAN.getDataType()) + ); + + final RecordSchema schema = new SimpleRecordSchema(fields); + + final PutDatabaseRecord.TableSchema tableSchema = new PutDatabaseRecord.TableSchema( + Arrays.asList( + new PutDatabaseRecord.ColumnDescription("id", 4, true, 2, false), + new PutDatabaseRecord.ColumnDescription("name", 12, true, 255, true), + new PutDatabaseRecord.ColumnDescription("code", 4, true, 10, true) + ), + false, + new HashSet<>(Arrays.asList("id")), + "" + ); + + runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, "false"); + runner.setProperty(PutDatabaseRecord.UNMATCHED_FIELD_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_FIELD); + runner.setProperty(PutDatabaseRecord.UNMATCHED_COLUMN_BEHAVIOR, PutDatabaseRecord.IGNORE_UNMATCHED_COLUMN); + runner.setProperty(PutDatabaseRecord.QUOTE_IDENTIFIERS, "true"); + runner.setProperty(PutDatabaseRecord.QUOTE_TABLE_IDENTIFIER, "true"); + final PutDatabaseRecord.DMLSettings settings = new PutDatabaseRecord.DMLSettings(runner.getProcessContext()); + + + assertEquals("test_catalog.test_schema.test_table", + processor.generateTableName(settings, "test_catalog", "test_schema", "test_table", tableSchema)); + } + + @Test + void testInsertMismatchedCompatibleDataTypes() throws InitializationException, ProcessException, SQLException, IOException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("dt", RecordFieldType.BIGINT); + + LocalDate testDate1 = LocalDate.of(2021, 1, 26); + Date jdbcDate1 = Date.valueOf(testDate1); // in local TZ + BigInteger nifiDate1 = BigInteger.valueOf(jdbcDate1.getTime()); // in local TZ + + LocalDate testDate2 = LocalDate.of(2021, 7, 26); + Date jdbcDate2 = Date.valueOf(testDate2); // in local TZ + BigInteger nifiDate2 = BigInteger.valueOf(jdbcDate2.getTime()); // in local TZ + + parser.addRecord(1, "rec1", 101, nifiDate1); + parser.addRecord(2, "rec2", 102, nifiDate2); + parser.addRecord(3, "rec3", 103, null); + parser.addRecord(4, "rec4", 104, null); + parser.addRecord(5, null, 105, null); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertEquals(101, rs.getInt(3)); + assertEquals(jdbcDate1.toString(), rs.getDate(4).toString()); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertEquals(102, rs.getInt(3)); + assertEquals(jdbcDate2.toString(), rs.getDate(4).toString()); + assertTrue(rs.next()); + assertEquals(3, rs.getInt(1)); + assertEquals("rec3", rs.getString(2)); + assertEquals(103, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertTrue(rs.next()); + assertEquals(4, rs.getInt(1)); + assertEquals("rec4", rs.getString(2)); + assertEquals(104, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertTrue(rs.next()); + assertEquals(5, rs.getInt(1)); + assertNull(rs.getString(2)); + assertEquals(105, rs.getInt(3)); + assertNull(rs.getDate(4)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertMismatchedNotCompatibleDataTypes() throws InitializationException, ProcessException, SQLException { + recreateTable(createPersons); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.STRING); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("dt", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.FLOAT.getDataType()).getFieldType()); + + LocalDate testDate1 = LocalDate.of(2021, 1, 26); + BigInteger nifiDate1 = BigInteger.valueOf(testDate1.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli()); // in UTC + Date jdbcDate1 = Date.valueOf(testDate1); // in local TZ + LocalDate testDate2 = LocalDate.of(2021, 7, 26); + BigInteger nifiDate2 = BigInteger.valueOf(testDate2.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli()); // in UTC + Date jdbcDate2 = Date.valueOf(testDate2); // in local TZ + + parser.addRecord("1", "rec1", 101, Arrays.asList(1.0, 2.0)); + parser.addRecord("2", "rec2", 102, Arrays.asList(3.0, 4.0)); + parser.addRecord("3", "rec3", 103, null); + parser.addRecord("4", "rec4", 104, null); + parser.addRecord("5", null, 105, null); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + // A SQLFeatureNotSupportedException exception is expected from Derby when you try to put the data as an ARRAY + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + } + + @Test + void testLongVarchar() throws InitializationException, ProcessException, SQLException { + // Manually create and drop the tables and schemas + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + try { + stmt.execute("DROP TABLE TEMP"); + } catch(final Exception e) { + // Do nothing, table may not exist + } + stmt.execute("CREATE TABLE TEMP (id integer primary key, name long varchar)"); + + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + + parser.addRecord(1, "rec1"); + parser.addRecord(2, "rec2"); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "TEMP"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + ResultSet rs = stmt.executeQuery("SELECT * FROM TEMP"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals("rec1", rs.getString(2)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals("rec2", rs.getString(2)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertWithDifferentColumnOrdering() throws InitializationException, ProcessException, SQLException, IOException { + // Manually create and drop the tables and schemas + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + try { + stmt.execute("DROP TABLE TEMP"); + } catch(final Exception e) { + // Do nothing, table may not exist + } + stmt.execute("CREATE TABLE TEMP (id integer primary key, code integer, name long varchar)"); + + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("code", RecordFieldType.INT); + + // change order of columns + parser.addRecord("rec1", 1, 101); + parser.addRecord("rec2", 2, 102); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "TEMP"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + ResultSet rs = stmt.executeQuery("SELECT * FROM TEMP"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + assertEquals(101, rs.getInt(2)); + assertEquals("rec1", rs.getString(3)); + assertTrue(rs.next()); + assertEquals(2, rs.getInt(1)); + assertEquals(102, rs.getInt(2)); + assertEquals("rec2", rs.getString(3)); + assertFalse(rs.next()); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertWithBlobClob() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))"; + + recreateTable(createTableWithBlob); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + byte[] bytes = "BLOB".getBytes(); + Byte[] blobRecordValue = new Byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + blobRecordValue[i] = Byte.valueOf(bytes[i]); + } + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("content", RecordFieldType.ARRAY); + + parser.addRecord(1, "rec1", 101, blobRecordValue); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + Clob clob = rs.getClob(2); + assertNotNull(clob); + char[] clobText = new char[5]; + int numBytes = clob.getCharacterStream().read(clobText); + assertEquals(4, numBytes); + // Ignore last character, it"s meant to ensure that only 4 bytes were read even though the buffer is 5 bytes + assertEquals("rec1", new String(clobText).substring(0, 4)); + Blob blob = rs.getBlob(3); + assertEquals("BLOB", new String(blob.getBytes(1, (int) blob.length()))); + assertEquals(101, rs.getInt(4)); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertWithBlobClobObjectArraySource() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))"; + + recreateTable(createTableWithBlob); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + byte[] bytes = "BLOB".getBytes(); + Object[] blobRecordValue = new Object[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + blobRecordValue[i] = bytes[i]; + } + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("content", RecordFieldType.ARRAY); + + parser.addRecord(1, "rec1", 101, blobRecordValue); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + Clob clob = rs.getClob(2); + assertNotNull(clob); + char[] clobText = new char[5]; + int numBytes = clob.getCharacterStream().read(clobText); + assertEquals(4, numBytes); + // Ignore last character, it"s meant to ensure that only 4 bytes were read even though the buffer is 5 bytes + assertEquals("rec1", new String(clobText).substring(0, 4)); + Blob blob = rs.getBlob(3); + assertEquals("BLOB", new String(blob.getBytes(1, (int) blob.length()))); + assertEquals(101, rs.getInt(4)); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertWithBlobStringSource() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))"; + + recreateTable(createTableWithBlob); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("content", RecordFieldType.STRING); + + parser.addRecord(1, "rec1", 101, "BLOB"); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1); + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + final ResultSet rs = stmt.executeQuery("SELECT * FROM PERSONS"); + assertTrue(rs.next()); + assertEquals(1, rs.getInt(1)); + Clob clob = rs.getClob(2); + assertNotNull(clob); + char[] clobText = new char[5]; + int numBytes = clob.getCharacterStream().read(clobText); + assertEquals(4, numBytes); + // Ignore last character, it"s meant to ensure that only 4 bytes were read even though the buffer is 5 bytes + assertEquals("rec1", new String(clobText).substring(0, 4)); + Blob blob = rs.getBlob(3); + assertEquals("BLOB", new String(blob.getBytes(1, (int) blob.length()))); + assertEquals(101, rs.getInt(4)); + + stmt.close(); + conn.close(); + } + + @Test + void testInsertWithBlobIntegerArraySource() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))"; + + recreateTable(createTableWithBlob); + final MockRecordParser parser = new MockRecordParser(); + runner.addControllerService("parser", parser); + runner.enableControllerService(parser); + + parser.addSchemaField("id", RecordFieldType.INT); + parser.addSchemaField("name", RecordFieldType.STRING); + parser.addSchemaField("code", RecordFieldType.INT); + parser.addSchemaField("content", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType()).getFieldType()); + + parser.addRecord(1, "rec1", 101, new Integer[] {1, 2, 3}); + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, "parser"); + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE); + runner.setProperty(PutDatabaseRecord.TABLE_NAME, "PERSONS"); + + runner.enqueue(new byte[0]); + runner.run(); + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_RETRY, 0); + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1); + } + private void recreateTable() throws ProcessException { try (final Connection conn = dbcp.getConnection(); final Statement stmt = conn.createStatement()) { @@ -182,6 +1769,52 @@ public class PutDatabaseRecordTest { } } + private int getTableSize() throws SQLException { + try (final Connection connection = dbcp.getConnection()) { + try (final Statement stmt = connection.createStatement()) { + final ResultSet rs = stmt.executeQuery("SELECT count(*) FROM PERSONS"); + assertTrue(rs.next()); + return rs.getInt(1); + } + } + } + + private void recreateTable(String createSQL) throws ProcessException, SQLException { + final Connection conn = dbcp.getConnection(); + final Statement stmt = conn.createStatement(); + try { + stmt.execute("drop table PERSONS"); + } catch (SQLException ignore) { + // Do nothing, may not have existed + } + stmt.execute(createSQL); + stmt.close(); + conn.close(); + } + + private Map createValues(final int id, final String name, final int code) { + final Map values = new HashMap<>(); + values.put("id", id); + values.put("name", name); + values.put("code", code); + return values; + } + + private Supplier createPreparedStatementSpy() { + final PreparedStatement[] spyStmt = new PreparedStatement[1]; + final Answer answer = (inv) -> new DelegatingConnection((Connection) inv.callRealMethod()) { + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + spyStmt[0] = spy(getDelegate().prepareStatement(sql)); + return spyStmt[0]; + } + }; + doAnswer(answer).when(dbcp).getConnection(ArgumentMatchers.anyMap()); + return () -> spyStmt[0]; + } + + + static class PutDatabaseRecordUnmatchedField extends PutDatabaseRecord { @Override SqlAndIncludedColumns generateInsert(RecordSchema recordSchema, String tableName, TableSchema tableSchema, DMLSettings settings) throws IllegalArgumentException { diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/SplitXmlTest.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/SplitXmlTest.java new file mode 100644 index 0000000000..95ccf7524d --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/SplitXmlTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard; + +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.file.Paths; + +public class SplitXmlTest { + private TestRunner runner; + + @BeforeEach + void setupRunner() { + runner = TestRunners.newTestRunner(new SplitXml()); + } + + @Test + void testShouldHandleXXEInTemplate() throws IOException { + final String xxeTemplateFilepath = "src/test/resources/xxe_template.xml"; + assertExternalEntitiesFailure(xxeTemplateFilepath); + } + + @Test + void testShouldHandleRemoteCallXXE() throws IOException { + final String xxeTemplateFilepath = "src/test/resources/xxe_from_report.xml"; + assertExternalEntitiesFailure(xxeTemplateFilepath); + } + + private void assertExternalEntitiesFailure(final String filePath) throws IOException { + runner.setProperty(SplitXml.SPLIT_DEPTH, "3"); + runner.enqueue(Paths.get(filePath)); + + runner.run(); + + runner.assertAllFlowFilesTransferred(SplitXml.REL_FAILURE); + } +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestCalculateRecordStats.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestCalculateRecordStats.java new file mode 100644 index 0000000000..83834cdc85 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestCalculateRecordStats.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.processors.standard; + +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.serialization.SimpleRecordSchema; + +import org.apache.nifi.serialization.record.MapRecord; +import org.apache.nifi.serialization.record.MockRecordParser; +import org.apache.nifi.serialization.record.RecordField; +import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class TestCalculateRecordStats { + TestRunner runner; + MockRecordParser recordParser; + RecordSchema personSchema; + + @BeforeEach + void setup() throws InitializationException { + runner = TestRunners.newTestRunner(CalculateRecordStats.class); + recordParser = new MockRecordParser(); + runner.addControllerService("recordReader", recordParser); + runner.setProperty(CalculateRecordStats.RECORD_READER, "recordReader"); + runner.enableControllerService(recordParser); + runner.assertValid(); + + recordParser.addSchemaField("id", RecordFieldType.INT); + List personFields = new ArrayList<>(); + RecordField nameField = new RecordField("name", RecordFieldType.STRING.getDataType()); + RecordField ageField = new RecordField("age", RecordFieldType.INT.getDataType()); + RecordField sportField = new RecordField("sport", RecordFieldType.STRING.getDataType()); + personFields.add(nameField); + personFields.add(ageField); + personFields.add(sportField); + personSchema = new SimpleRecordSchema(personFields); + recordParser.addSchemaField("person", RecordFieldType.RECORD); + } + + @Test + void testNoNullOrEmptyRecordFields() { + final List sports = Arrays.asList("Soccer", "Soccer", "Soccer", "Football", "Football", "Basketball"); + final Map expectedAttributes = new HashMap<>(); + expectedAttributes.put("recordStats.sport.Soccer", "3"); + expectedAttributes.put("recordStats.sport.Football", "2"); + expectedAttributes.put("recordStats.sport.Basketball", "1"); + expectedAttributes.put("recordStats.sport", "6"); + expectedAttributes.put("record.count", "6"); + + commonTest(Collections.singletonMap("sport", "/person/sport"), sports, expectedAttributes); + } + + @Test + void testWithNullFields() { + final List sports = Arrays.asList("Soccer", null, null, "Football", null, "Basketball"); + final Map expectedAttributes = new HashMap<>(); + expectedAttributes.put("recordStats.sport.Soccer", "1"); + expectedAttributes.put("recordStats.sport.Football", "1"); + expectedAttributes.put("recordStats.sport.Basketball", "1"); + expectedAttributes.put("recordStats.sport", "3"); + expectedAttributes.put("record.count", "6"); + + commonTest(Collections.singletonMap("sport", "/person/sport"), sports, expectedAttributes); + } + + @Test + void testWithFilters() { + final List sports = Arrays.asList("Soccer", "Soccer", "Soccer", "Football", "Football", "Basketball"); + final Map expectedAttributes = new HashMap<>(); + expectedAttributes.put("recordStats.sport.Soccer", "3"); + expectedAttributes.put("recordStats.sport.Basketball", "1"); + expectedAttributes.put("recordStats.sport", "4"); + expectedAttributes.put("record.count", "6"); + + final Map propz = Collections.singletonMap("sport", "/person/sport[. != 'Football']"); + + commonTest(propz, sports, expectedAttributes); + } + + @Test + void testWithSizeLimit() { + runner.setProperty(CalculateRecordStats.LIMIT, "3"); + final List sports = Arrays.asList("Soccer", "Soccer", "Soccer", "Football", "Football", + "Basketball", "Baseball", "Baseball", "Baseball", "Baseball", + "Skiing", "Skiing", "Skiing", "Snowboarding"); + final Map expectedAttributes = new HashMap<>(); + expectedAttributes.put("recordStats.sport.Skiing", "3"); + expectedAttributes.put("recordStats.sport.Soccer", "3"); + expectedAttributes.put("recordStats.sport.Baseball", "4"); + expectedAttributes.put("recordStats.sport", String.valueOf(sports.size())); + expectedAttributes.put("record.count", String.valueOf(sports.size())); + + final Map propz = Collections.singletonMap("sport", "/person/sport"); + + commonTest(propz, sports, expectedAttributes); + } + + private void commonTest(Map procProperties, List sports, Map expectedAttributes) { + int index = 1; + for (final String sport : sports) { + final Map newRecord = new HashMap<>(); + newRecord.put("name", "John Doe"); + newRecord.put("age", 48); + newRecord.put("sport", sport); + recordParser.addRecord(index++, new MapRecord(personSchema, newRecord)); + } + + for (final Map.Entry property : procProperties.entrySet()) { + runner.setProperty(property.getKey(), property.getValue()); + } + + runner.enqueue(""); + runner.run(); + runner.assertTransferCount(CalculateRecordStats.REL_FAILURE, 0); + runner.assertTransferCount(CalculateRecordStats.REL_SUCCESS, 1); + + final List flowFiles = runner.getFlowFilesForRelationship(CalculateRecordStats.REL_SUCCESS); + final MockFlowFile ff = flowFiles.get(0); + for (final Map.Entry expectedAttribute : expectedAttributes.entrySet()) { + final String key = expectedAttribute.getKey(); + final String value = expectedAttribute.getValue(); + assertNotNull(ff.getAttribute(key), String.format("Missing %s", key)); + assertEquals(value, ff.getAttribute(key)); + } + } +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestEncryptContent.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestEncryptContent.java index 8cdb3b59b2..afa00d05b9 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestEncryptContent.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestEncryptContent.java @@ -16,14 +16,23 @@ */ package org.apache.nifi.processors.standard; +import groovy.time.TimeCategory; +import groovy.time.TimeDuration; +import org.apache.commons.codec.DecoderException; import org.apache.commons.codec.binary.Hex; import org.apache.nifi.components.AllowableValue; import org.apache.nifi.components.ValidationResult; import org.apache.nifi.security.util.EncryptionMethod; import org.apache.nifi.security.util.KeyDerivationFunction; +import org.apache.nifi.security.util.crypto.Argon2CipherProvider; +import org.apache.nifi.security.util.crypto.Argon2SecureHasher; +import org.apache.nifi.security.util.crypto.CipherUtility; +import org.apache.nifi.security.util.crypto.KeyedEncryptor; import org.apache.nifi.security.util.crypto.PasswordBasedEncryptor; +import org.apache.nifi.security.util.crypto.RandomIVPBECipherProvider; import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.MockProcessContext; +import org.apache.nifi.util.StringUtils; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.bouncycastle.bcpg.BCPGInputStream; @@ -31,10 +40,10 @@ import org.bouncycastle.bcpg.SymmetricKeyEncSessionPacket; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import javax.crypto.Cipher; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -42,34 +51,42 @@ import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.nio.file.Paths; import java.security.Security; +import java.text.ParseException; +import java.text.SimpleDateFormat; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Date; import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import static org.bouncycastle.openpgp.PGPUtil.getDecoderStream; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class TestEncryptContent { - private static final Logger logger = LoggerFactory.getLogger(TestEncryptContent.class); - private static AllowableValue[] getPGPCipherList() { try{ Method method = EncryptContent.class.getDeclaredMethod("buildPGPSymmetricCipherAllowableValues"); method.setAccessible(true); return ((AllowableValue[]) method.invoke(null)); } catch (Exception e){ - logger.error("Cannot access buildPGPSymmetricCipherAllowableValues", e); fail("Cannot access buildPGPSymmetricCipherAllowableValues"); } return null; } + private static final List SUPPORTED_KEYED_ENCRYPTION_METHODS = Arrays + .stream(EncryptionMethod.values()) + .filter(method -> method.isKeyedCipher() && method != EncryptionMethod.AES_CBC_NO_PADDING) + .collect(Collectors.toList()); + @BeforeEach public void setUp() { Security.addProvider(new BouncyCastleProvider()); @@ -93,7 +110,6 @@ public class TestEncryptContent { continue; } - logger.info("Attempting {}", encryptionMethod.name()); testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); @@ -112,13 +128,43 @@ public class TestEncryptContent { testRunner.run(); testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); - logger.info("Successfully decrypted {}", encryptionMethod.name()); + flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + flowFile.assertContentEquals(new File("src/test/resources/hello.txt")); + } + } + + @Test + public void testKeyedCiphersRoundTrip() throws IOException { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + final String RAW_KEY_HEX = StringUtils.repeat("ab", 16); + testRunner.setProperty(EncryptContent.RAW_KEY_HEX, RAW_KEY_HEX); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()); + + for (final EncryptionMethod encryptionMethod : SUPPORTED_KEYED_ENCRYPTION_METHODS) { + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + testRunner.enqueue(Paths.get("src/test/resources/hello.txt")); + testRunner.clearTransferState(); + testRunner.run(); + + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); + + MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + testRunner.assertQueueEmpty(); + + testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE); + testRunner.enqueue(flowFile); + testRunner.clearTransferState(); + testRunner.run(); + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); flowFile.assertContentEquals(new File("src/test/resources/hello.txt")); } } + @Test public void testPGPCiphersRoundTrip() { final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); @@ -197,27 +243,22 @@ public class TestEncryptContent { } @Test - public void testShouldDetermineMaxKeySizeForAlgorithms() throws IOException { - // Arrange + public void testShouldDetermineMaxKeySizeForAlgorithms() { final String AES_ALGORITHM = EncryptionMethod.MD5_256AES.getAlgorithm(); final String DES_ALGORITHM = EncryptionMethod.MD5_DES.getAlgorithm(); final int AES_MAX_LENGTH = Integer.MAX_VALUE; final int DES_MAX_LENGTH = Integer.MAX_VALUE; - // Act int determinedAESMaxLength = PasswordBasedEncryptor.getMaxAllowedKeyLength(AES_ALGORITHM); int determinedTDESMaxLength = PasswordBasedEncryptor.getMaxAllowedKeyLength(DES_ALGORITHM); - // Assert - assert determinedAESMaxLength == AES_MAX_LENGTH; - assert determinedTDESMaxLength == DES_MAX_LENGTH; + assertEquals(AES_MAX_LENGTH, determinedAESMaxLength); + assertEquals(DES_MAX_LENGTH, determinedTDESMaxLength); } @Test public void testShouldDecryptOpenSSLRawSalted() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); final String password = "thisIsABadPassword"; @@ -229,27 +270,20 @@ public class TestEncryptContent { testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, method.name()); testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE); - // Act testRunner.enqueue(Paths.get("src/test/resources/TestEncryptContent/salted_raw.enc")); testRunner.clearTransferState(); testRunner.run(); - // Assert testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); testRunner.assertQueueEmpty(); MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); - logger.info("Decrypted contents (hex): {}", Hex.encodeHexString(flowFile.toByteArray())); - logger.info("Decrypted contents: {}", new String(flowFile.toByteArray(), StandardCharsets.UTF_8)); - // Assert flowFile.assertContentEquals(new File("src/test/resources/TestEncryptContent/plain.txt")); } @Test public void testShouldDecryptOpenSSLRawUnsalted() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); final String password = "thisIsABadPassword"; @@ -261,33 +295,18 @@ public class TestEncryptContent { testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, method.name()); testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE); - // Act testRunner.enqueue(Paths.get("src/test/resources/TestEncryptContent/unsalted_raw.enc")); testRunner.clearTransferState(); testRunner.run(); - // Assert testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); testRunner.assertQueueEmpty(); MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); - logger.info("Decrypted contents (hex): {}", Hex.encodeHexString(flowFile.toByteArray())); - logger.info("Decrypted contents: {}", new String(flowFile.toByteArray(), StandardCharsets.UTF_8)); - // Assert flowFile.assertContentEquals(new File("src/test/resources/TestEncryptContent/plain.txt")); } - @Test - public void testDecryptShouldDefaultToNone() throws IOException { - // Arrange - final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); - - // Assert - assertEquals(testRunner.getProcessor().getPropertyDescriptor(EncryptContent.KEY_DERIVATION_FUNCTION - .getName()).getDefaultValue(), KeyDerivationFunction.NONE.name(), "Decrypt should default to None"); - } - @Test public void testDecryptSmallerThanSaltSize() { final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class); @@ -447,10 +466,6 @@ public class TestEncryptContent { pc = (MockProcessContext) runner.getProcessContext(); results = pc.validate(); - for (ValidationResult vr : results) { - logger.info(vr.toString()); - } - // The default validation error is: // Raw key hex cannot be empty final String RAW_KEY_ERROR = "'raw-key-hex' is invalid because Raw Key (hexadecimal) is " + @@ -552,4 +567,588 @@ public class TestEncryptContent { runner.removeProperty(EncryptContent.PGP_SYMMETRIC_ENCRYPTION_CIPHER); runner.assertValid(); } + + @Test + void testShouldValidateMaxKeySizeForAlgorithmsOnUnlimitedStrengthJVM() { + final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class); + Collection results; + MockProcessContext pc; + + EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC; + + // Integer.MAX_VALUE or 128, so use 256 or 128 + final int MAX_KEY_LENGTH = Math.min(PasswordBasedEncryptor.getMaxAllowedKeyLength(encryptionMethod.getAlgorithm()), 256); + final String TOO_LONG_KEY_HEX = StringUtils.repeat("ab", (MAX_KEY_LENGTH / 8 + 1)); + + runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()); + runner.setProperty(EncryptContent.RAW_KEY_HEX, TOO_LONG_KEY_HEX); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertEquals(1, results.size()); + ValidationResult vr = results.iterator().next(); + + String expectedResult = "'raw-key-hex' is invalid because Key must be valid length [128, 192, 256]"; + String message = "'" + vr.toString() + "' contains '" + expectedResult + "'"; + assertTrue(vr.toString().contains(expectedResult), message); + } + + @Test + void testShouldValidateKeyFormatAndSizeForAlgorithms() { + final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class); + Collection results; + MockProcessContext pc; + + EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC; + + final int INVALID_KEY_LENGTH = 120; + final String INVALID_KEY_HEX = StringUtils.repeat("ab", (INVALID_KEY_LENGTH / 8)); + + runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()); + runner.setProperty(EncryptContent.RAW_KEY_HEX, INVALID_KEY_HEX); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertEquals(1, results.size()); + ValidationResult keyLengthInvalidVR = results.iterator().next(); + + String expectedResult = "'raw-key-hex' is invalid because Key must be valid length [128, 192, 256]"; + String message = "'" + keyLengthInvalidVR.toString() + "' contains '" + expectedResult + "'"; + assertTrue(keyLengthInvalidVR.toString().contains(expectedResult), message); + } + + @Test + void testShouldValidateKDFWhenKeyedCipherSelected() { + final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class); + Collection results; + MockProcessContext pc; + + final int VALID_KEY_LENGTH = 128; + final String VALID_KEY_HEX = StringUtils.repeat("ab", (VALID_KEY_LENGTH / 8)); + + runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + for (final EncryptionMethod encryptionMethod : SUPPORTED_KEYED_ENCRYPTION_METHODS) { + runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + + // Scenario 1: Legacy KDF + keyed cipher -> validation error + final List invalidKDFs = Arrays.asList(KeyDerivationFunction.NIFI_LEGACY, KeyDerivationFunction.OPENSSL_EVP_BYTES_TO_KEY); + for (final KeyDerivationFunction invalidKDF : invalidKDFs) { + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, invalidKDF.name()); + runner.setProperty(EncryptContent.RAW_KEY_HEX, VALID_KEY_HEX); + runner.removeProperty(EncryptContent.PASSWORD); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertEquals(1, results.size()); + ValidationResult keyLengthInvalidVR = results.iterator().next(); + + String expectedResult = String.format("'key-derivation-function' is invalid because Key Derivation Function is required to be BCRYPT, SCRYPT, PBKDF2, ARGON2, NONE when using " + + "algorithm %s", encryptionMethod.getAlgorithm()); + String message = "'" + keyLengthInvalidVR.toString() + "' contains '" + expectedResult + "'"; + assertTrue(keyLengthInvalidVR.toString().contains(expectedResult), message); + } + + // Scenario 2: No KDF + keyed cipher + raw-key-hex -> valid + + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()); + runner.setProperty(EncryptContent.RAW_KEY_HEX, VALID_KEY_HEX); + runner.removeProperty(EncryptContent.PASSWORD); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertTrue(results.isEmpty()); + + // Scenario 3: Strong KDF + keyed cipher + password -> valid + final List validKDFs = Arrays.asList(KeyDerivationFunction.BCRYPT, + KeyDerivationFunction.SCRYPT, + KeyDerivationFunction.PBKDF2, + KeyDerivationFunction.ARGON2); + for (final KeyDerivationFunction validKDF : validKDFs) { + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, validKDF.name()); + runner.setProperty(EncryptContent.PASSWORD, "thisIsABadPassword"); + runner.removeProperty(EncryptContent.RAW_KEY_HEX); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertTrue(results.isEmpty()); + } + } + } + + @Test + void testShouldValidateKeyMaterialSourceWhenKeyedCipherSelected() { + final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class); + Collection results; + MockProcessContext pc; + + final int VALID_KEY_LENGTH = 128; + final String VALID_KEY_HEX = StringUtils.repeat("ab", (VALID_KEY_LENGTH / 8)); + + final String VALID_PASSWORD = "thisIsABadPassword"; + + runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + KeyDerivationFunction none = KeyDerivationFunction.NONE; + + // Scenario 1 - RKH w/ KDF NONE & em in [CBC, CTR, GCM] (no password) + for (final EncryptionMethod kem : SUPPORTED_KEYED_ENCRYPTION_METHODS) { + runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, kem.name()); + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, none.name()); + + runner.setProperty(EncryptContent.RAW_KEY_HEX, VALID_KEY_HEX); + runner.removeProperty(EncryptContent.PASSWORD); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertTrue(results.isEmpty()); + + // Scenario 2 - PW w/ KDF in [BCRYPT, SCRYPT, PBKDF2, ARGON2] & em in [CBC, CTR, GCM] (no RKH) + final List validKDFs = Arrays + .stream(KeyDerivationFunction.values()) + .filter(it -> it.isStrongKDF()) + .collect(Collectors.toList()); + for (final KeyDerivationFunction kdf : validKDFs) { + runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, kem.name()); + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()); + + runner.removeProperty(EncryptContent.RAW_KEY_HEX); + runner.setProperty(EncryptContent.PASSWORD, VALID_PASSWORD); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertTrue(results.isEmpty()); + } + } + } + + @Test + void testShouldValidateKDFWhenPBECipherSelected() { + final TestRunner runner = TestRunners.newTestRunner(EncryptContent.class); + Collection results; + MockProcessContext pc; + final String PASSWORD = "short"; + + final List encryptionMethods = Arrays + .stream(EncryptionMethod.values()) + .filter(it -> it.getAlgorithm().startsWith("PBE")) + .collect(Collectors.toList()); + + runner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + runner.setProperty(EncryptContent.PASSWORD, PASSWORD); + runner.setProperty(EncryptContent.ALLOW_WEAK_CRYPTO, "allowed"); + + for (final EncryptionMethod encryptionMethod : encryptionMethods) { + runner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + + final List invalidKDFs = Arrays.asList( + KeyDerivationFunction.NONE, + KeyDerivationFunction.BCRYPT, + KeyDerivationFunction.SCRYPT, + KeyDerivationFunction.PBKDF2, + KeyDerivationFunction.ARGON2 + ); + for (final KeyDerivationFunction invalidKDF : invalidKDFs) { + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, invalidKDF.name()); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertEquals(1, results.size()); + ValidationResult keyLengthInvalidVR = results.iterator().next(); + + String expectedResult = String.format("'Key Derivation Function' is invalid because Key Derivation Function is required to be NIFI_LEGACY, OPENSSL_EVP_BYTES_TO_KEY when using " + + "algorithm %s", encryptionMethod.getAlgorithm()); + String message = "'" + keyLengthInvalidVR.toString() + "' contains '" + expectedResult + "'"; + assertTrue(keyLengthInvalidVR.toString().contains(expectedResult), message); + } + + final List validKDFs = Arrays.asList(KeyDerivationFunction.NIFI_LEGACY, KeyDerivationFunction.OPENSSL_EVP_BYTES_TO_KEY); + for (final KeyDerivationFunction validKDF : validKDFs) { + runner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, validKDF.name()); + + runner.enqueue(new byte[0]); + pc = (MockProcessContext) runner.getProcessContext(); + + results = pc.validate(); + + assertEquals(0, results.size()); + } + } + } + + @Test + void testDecryptAesCbcNoPadding() throws DecoderException, IOException { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + final String RAW_KEY_HEX = StringUtils.repeat("ab", 16); + testRunner.setProperty(EncryptContent.RAW_KEY_HEX, RAW_KEY_HEX); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NONE.name()); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, EncryptionMethod.AES_CBC_NO_PADDING.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE); + + final String content = "ExactBlockSizeRequiredForProcess"; + final byte[] bytes = content.getBytes(StandardCharsets.UTF_8); + final ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + final KeyedEncryptor encryptor = new KeyedEncryptor(EncryptionMethod.AES_CBC_NO_PADDING, Hex.decodeHex(RAW_KEY_HEX)); + encryptor.getEncryptionCallback().process(inputStream, outputStream); + outputStream.close(); + + final byte[] encrypted = outputStream.toByteArray(); + testRunner.enqueue(encrypted); + testRunner.run(); + + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); + MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + flowFile.assertContentEquals(content); + } + + @Test + void testArgon2EncryptionShouldWriteAttributesWithEncryptionMetadata() throws ParseException { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + KeyDerivationFunction kdf = KeyDerivationFunction.ARGON2; + EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC; + + testRunner.setProperty(EncryptContent.PASSWORD, "thisIsABadPassword"); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + String PLAINTEXT = "This is a plaintext message. "; + + testRunner.enqueue(PLAINTEXT); + testRunner.clearTransferState(); + testRunner.run(); + + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); + + MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + testRunner.assertQueueEmpty(); + + byte[] flowfileContentBytes = flowFile.getData(); + + int ivDelimiterStart = CipherUtility.findSequence(flowfileContentBytes, RandomIVPBECipherProvider.IV_DELIMITER); + + final byte[] EXPECTED_KDF_SALT_BYTES = extractFullSaltFromCipherBytes(flowfileContentBytes); + final String EXPECTED_KDF_SALT = new String(EXPECTED_KDF_SALT_BYTES); + final String EXPECTED_SALT_HEX = extractRawSaltHexFromFullSalt(EXPECTED_KDF_SALT_BYTES, kdf); + + final String EXPECTED_IV_HEX = Hex.encodeHexString(Arrays.copyOfRange(flowfileContentBytes, ivDelimiterStart - 16, ivDelimiterStart)); + + // Assert the timestamp attribute was written and is accurate + final TimeDuration diff = calculateTimestampDifference(new Date(), flowFile.getAttribute("encryptcontent.timestamp")); + assertTrue(diff.toMilliseconds() < 1_000); + assertEquals(encryptionMethod.name(), flowFile.getAttribute("encryptcontent.algorithm")); + assertEquals(kdf.name(), flowFile.getAttribute("encryptcontent.kdf")); + assertEquals("encrypted", flowFile.getAttribute("encryptcontent.action")); + assertEquals(EXPECTED_SALT_HEX, flowFile.getAttribute("encryptcontent.salt")); + assertEquals("16", flowFile.getAttribute("encryptcontent.salt_length")); + assertEquals(EXPECTED_KDF_SALT, flowFile.getAttribute("encryptcontent.kdf_salt")); + final int kdfSaltLength = Integer.valueOf(flowFile.getAttribute("encryptcontent.kdf_salt_length")); + assertTrue(kdfSaltLength >= 29 && kdfSaltLength <= 54); + assertEquals(EXPECTED_IV_HEX, flowFile.getAttribute("encryptcontent.iv")); + assertEquals("16", flowFile.getAttribute("encryptcontent.iv_length")); + assertEquals(String.valueOf(PLAINTEXT.length()), flowFile.getAttribute("encryptcontent.plaintext_length")); + assertEquals(String.valueOf(flowfileContentBytes.length), flowFile.getAttribute("encryptcontent.cipher_text_length")); + } + + @Test + void testKeyedEncryptionShouldWriteAttributesWithEncryptionMetadata() throws ParseException { + // Arrange + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + KeyDerivationFunction kdf = KeyDerivationFunction.NONE; + EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC; + + testRunner.setProperty(EncryptContent.RAW_KEY_HEX, "0123456789ABCDEFFEDCBA9876543210"); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + String PLAINTEXT = "This is a plaintext message. "; + + testRunner.enqueue(PLAINTEXT); + testRunner.clearTransferState(); + testRunner.run(); + + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); + + MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + testRunner.assertQueueEmpty(); + + byte[] flowfileContentBytes = flowFile.getData(); + + int ivDelimiterStart = CipherUtility.findSequence(flowfileContentBytes, RandomIVPBECipherProvider.IV_DELIMITER); + assertEquals(16, ivDelimiterStart); + + final TimeDuration diff = calculateTimestampDifference(new Date(), flowFile.getAttribute("encryptcontent.timestamp")); + + // Assert the timestamp attribute was written and is accurate + assertTrue(diff.toMilliseconds() < 1_000); + + final String EXPECTED_IV_HEX = Hex.encodeHexString(Arrays.copyOfRange(flowfileContentBytes, 0, ivDelimiterStart)); + final int EXPECTED_CIPHER_TEXT_LENGTH = CipherUtility.calculateCipherTextLength(PLAINTEXT.length(), 0); + assertEquals(encryptionMethod.name(), flowFile.getAttribute("encryptcontent.algorithm")); + assertEquals(kdf.name(), flowFile.getAttribute("encryptcontent.kdf")); + assertEquals("encrypted", flowFile.getAttribute("encryptcontent.action")); + assertEquals(EXPECTED_IV_HEX, flowFile.getAttribute("encryptcontent.iv")); + assertEquals("16", flowFile.getAttribute("encryptcontent.iv_length")); + assertEquals(String.valueOf(PLAINTEXT.length()), flowFile.getAttribute("encryptcontent.plaintext_length")); + assertEquals(String.valueOf(EXPECTED_CIPHER_TEXT_LENGTH), flowFile.getAttribute("encryptcontent.cipher_text_length")); + } + + @Test + void testDifferentCompatibleConfigurations() throws Exception { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + KeyDerivationFunction argon2 = KeyDerivationFunction.ARGON2; + EncryptionMethod aesCbcEM = EncryptionMethod.AES_CBC; + int keyLength = CipherUtility.parseKeyLengthFromAlgorithm(aesCbcEM.getAlgorithm()); + + final String PASSWORD = "thisIsABadPassword"; + testRunner.setProperty(EncryptContent.PASSWORD, PASSWORD); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, argon2.name()); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, aesCbcEM.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + String PLAINTEXT = "This is a plaintext message. "; + + testRunner.enqueue(PLAINTEXT); + testRunner.clearTransferState(); + testRunner.run(); + + MockFlowFile encryptedFlowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + byte[] fullCipherBytes = encryptedFlowFile.getData(); + + // Extract the KDF salt from the encryption metadata in the flowfile attribute + String argon2Salt = encryptedFlowFile.getAttribute("encryptcontent.kdf_salt"); + Argon2SecureHasher a2sh = new Argon2SecureHasher(keyLength / 8); + byte[] fullSaltBytes = argon2Salt.getBytes(StandardCharsets.UTF_8); + byte[] rawSaltBytes = Hex.decodeHex(encryptedFlowFile.getAttribute("encryptcontent.salt")); + byte[] keyBytes = a2sh.hashRaw(PASSWORD.getBytes(StandardCharsets.UTF_8), rawSaltBytes); + String keyHex = Hex.encodeHexString(keyBytes); + + byte[] ivBytes = Hex.decodeHex(encryptedFlowFile.getAttribute("encryptcontent.iv")); + + // Sanity check the encryption + Argon2CipherProvider a2cp = new Argon2CipherProvider(); + Cipher sanityCipher = a2cp.getCipher(aesCbcEM, PASSWORD, fullSaltBytes, ivBytes, CipherUtility.parseKeyLengthFromAlgorithm(aesCbcEM.getAlgorithm()), false); + byte[] cipherTextBytes = Arrays.copyOfRange(fullCipherBytes, fullCipherBytes.length - 32, fullCipherBytes.length); + byte[] recoveredBytes = sanityCipher.doFinal(cipherTextBytes); + + // Configure decrypting processor with raw key + KeyDerivationFunction kdf = KeyDerivationFunction.NONE; + EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC; + + testRunner.setProperty(EncryptContent.RAW_KEY_HEX, keyHex); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, kdf.name()); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE); + testRunner.removeProperty(EncryptContent.PASSWORD); + + testRunner.enqueue(fullCipherBytes); + testRunner.clearTransferState(); + testRunner.run(); + + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); + + MockFlowFile decryptedFlowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + testRunner.assertQueueEmpty(); + + byte[] flowfileContentBytes = decryptedFlowFile.getData(); + + assertArrayEquals(recoveredBytes, flowfileContentBytes); + } + + @Test + void testShouldCheckLengthOfPasswordWhenNotAllowed() { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NIFI_LEGACY.name()); + + Collection results; + MockProcessContext pc; + + final List encryptionMethods = Arrays + .stream(EncryptionMethod.values()) + .filter(it -> it.getAlgorithm().startsWith("PBE")) + .collect(Collectors.toList()); + + testRunner.setProperty(EncryptContent.ALLOW_WEAK_CRYPTO, "not-allowed"); + + // Use .find instead of .each to allow "breaks" using return false + for (final EncryptionMethod encryptionMethod : encryptionMethods) { + // Determine the minimum of the algorithm-accepted length or the global safe minimum to ensure only one validation result + final int shortPasswordLength = Math.min(PasswordBasedEncryptor.getMinimumSafePasswordLength() - 1, + CipherUtility.getMaximumPasswordLengthForAlgorithmOnLimitedStrengthCrypto(encryptionMethod) - 1); + String shortPassword = StringUtils.repeat("x", shortPasswordLength); + if (encryptionMethod.isUnlimitedStrength() || encryptionMethod.isKeyedCipher()) { + continue; + // cannot test unlimited strength in unit tests because it's not enabled by the JVM by default. + } + + testRunner.setProperty(EncryptContent.PASSWORD, shortPassword); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + testRunner.clearTransferState(); + testRunner.enqueue(new byte[0]); + pc = (MockProcessContext) testRunner.getProcessContext(); + + results = pc.validate(); + + assertEquals(1, results.size()); + ValidationResult passwordLengthVR = results.iterator().next(); + + String expectedResult = String.format("'Password' is invalid because Password length less than %s characters is potentially unsafe. " + + "See Admin Guide.", PasswordBasedEncryptor.getMinimumSafePasswordLength()); + String message = "'" + passwordLengthVR.toString() + "' contains '" + expectedResult + "'"; + assertTrue(passwordLengthVR.toString().contains(expectedResult), message); + } + } + + @Test + void testShouldNotCheckLengthOfPasswordWhenAllowed() { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.NIFI_LEGACY.name()); + + Collection results; + MockProcessContext pc; + + final List encryptionMethods = Arrays + .stream(EncryptionMethod.values()) + .filter(it -> it.getAlgorithm().startsWith("PBE")) + .collect(Collectors.toList()); + + testRunner.setProperty(EncryptContent.ALLOW_WEAK_CRYPTO, "allowed"); + + for (final EncryptionMethod encryptionMethod : encryptionMethods) { + // Determine the minimum of the algorithm-accepted length or the global safe minimum to ensure only one validation result + final int shortPasswordLength = Math.min(PasswordBasedEncryptor.getMinimumSafePasswordLength() - 1, + CipherUtility.getMaximumPasswordLengthForAlgorithmOnLimitedStrengthCrypto(encryptionMethod) - 1); + String shortPassword = StringUtils.repeat("x", shortPasswordLength); + if (encryptionMethod.isUnlimitedStrength() || encryptionMethod.isKeyedCipher()) { + continue; + // cannot test unlimited strength in unit tests because it's not enabled by the JVM by default. + } + + testRunner.setProperty(EncryptContent.PASSWORD, shortPassword); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + testRunner.clearTransferState(); + testRunner.enqueue(new byte[0]); + pc = (MockProcessContext) testRunner.getProcessContext(); + + results = pc.validate(); + + assertEquals(0, results.size(), results.toString()); + } + } + + @Test + void testPGPPasswordShouldSupportExpressionLanguage() { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.DECRYPT_MODE); + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, EncryptionMethod.PGP.name()); + testRunner.setProperty(EncryptContent.PRIVATE_KEYRING, "src/test/resources/TestEncryptContent/secring.gpg"); + + Collection results; + MockProcessContext pc; + + // Verify this is the correct password + final String passphraseWithoutEL = "thisIsABadPassword"; + testRunner.setProperty(EncryptContent.PRIVATE_KEYRING_PASSPHRASE, passphraseWithoutEL); + + testRunner.clearTransferState(); + testRunner.enqueue(new byte[0]); + pc = (MockProcessContext) testRunner.getProcessContext(); + + results = pc.validate(); + assertEquals(0, results.size(), results.toString()); + + final String passphraseWithEL = "${literal('thisIsABadPassword')}"; + testRunner.setProperty(EncryptContent.PRIVATE_KEYRING_PASSPHRASE, passphraseWithEL); + + testRunner.clearTransferState(); + testRunner.enqueue(new byte[0]); + + results = pc.validate(); + + assertEquals(0, results.size(), results.toString()); + } + + @Test + void testArgon2ShouldIncludeFullSalt() throws IOException { + final TestRunner testRunner = TestRunners.newTestRunner(new EncryptContent()); + testRunner.setProperty(EncryptContent.PASSWORD, "thisIsABadPassword"); + testRunner.setProperty(EncryptContent.KEY_DERIVATION_FUNCTION, KeyDerivationFunction.ARGON2.name()); + + EncryptionMethod encryptionMethod = EncryptionMethod.AES_CBC; + + testRunner.setProperty(EncryptContent.ENCRYPTION_ALGORITHM, encryptionMethod.name()); + testRunner.setProperty(EncryptContent.MODE, EncryptContent.ENCRYPT_MODE); + + testRunner.enqueue(Paths.get("src/test/resources/hello.txt")); + testRunner.clearTransferState(); + testRunner.run(); + + testRunner.assertAllFlowFilesTransferred(EncryptContent.REL_SUCCESS, 1); + + MockFlowFile flowFile = testRunner.getFlowFilesForRelationship(EncryptContent.REL_SUCCESS).get(0); + testRunner.assertQueueEmpty(); + + final String flowFileContent = flowFile.getContent(); + + final String fullSalt = flowFileContent.substring(0, flowFileContent.indexOf(new String(RandomIVPBECipherProvider.SALT_DELIMITER, StandardCharsets.UTF_8))); + + boolean isValidFormattedSalt = Argon2CipherProvider.isArgon2FormattedSalt(fullSalt); + assertTrue(isValidFormattedSalt); + + boolean fullSaltIsValidLength = fullSalt.getBytes().length >= 49 && fullSalt.getBytes().length <= 57; + assertTrue(fullSaltIsValidLength); + } + + private static byte[] extractFullSaltFromCipherBytes(byte[] cipherBytes) { + int saltDelimiterStart = CipherUtility.findSequence(cipherBytes, RandomIVPBECipherProvider.SALT_DELIMITER); + return Arrays.copyOfRange(cipherBytes, 0, saltDelimiterStart); + } + + private static String extractRawSaltHexFromFullSalt(byte[] fullSaltBytes, KeyDerivationFunction kdf) { + // Salt will be in Base64 (or Radix64) for strong KDFs + byte[] rawSaltBytes = CipherUtility.extractRawSalt(fullSaltBytes, kdf); + String rawSaltHex = Hex.encodeHexString(rawSaltBytes); + return rawSaltHex; + } + + private static TimeDuration calculateTimestampDifference(Date date, String timestamp) throws ParseException { + SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS Z"); + Date parsedTimestamp = formatter.parse(timestamp); + + return TimeCategory.minus(date, parsedTimestamp); + } } diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestFlattenJson.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestFlattenJson.java new file mode 100644 index 0000000000..cf5ac91d77 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestFlattenJson.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License") you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.processors.standard; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TestFlattenJson { + private static final ObjectMapper mapper = new ObjectMapper(); + + private TestRunner testRunner; + + @BeforeEach + void setupRunner() { + testRunner = TestRunners.newTestRunner(FlattenJson.class); + } + + @Test + void testFlatten() throws JsonProcessingException { + final String json = "{\n" + + " \"test\": {\n" + + " \"msg\": \"Hello, world\"\n" + + " },\n" + + " \"first\": {\n" + + " \"second\": {\n" + + " \"third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + final Map parsed = (Map) baseTest(testRunner, json, 2); + assertEquals(parsed.get("test.msg"), "Hello, world", "test.msg should exist, but doesn't"); + assertEquals(parsed.get("first.second.third"), + Arrays.asList("one", "two", "three", "four", "five"), + "Three level block doesn't exist."); + } + + @Test + void testFlattenRecordSet() throws JsonProcessingException { + final String json = "[\n" + + " {\n" + + " \"first\": {\n" + + " \"second\": \"Hello\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"first\": {\n" + + " \"second\": \"World\"\n" + + " }\n" + + " }\n" + + "]"; + + final List expected = Arrays.asList("Hello", "World"); + final List parsed = (List) baseTest(testRunner, json, 2); + assertTrue(parsed instanceof List, "Not a list"); + for (int i = 0; i < parsed.size(); i++) { + final Map map = (Map) parsed.get(i); + assertEquals(map.get("first.second"), expected.get(i), "Missing values."); + } + } + + @Test + void testDifferentSeparator() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": {\n" + + " \"third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + testRunner.setProperty(FlattenJson.SEPARATOR, "_"); + final Map parsed = (Map) baseTest(testRunner, json, 1); + + assertEquals(parsed.get("first_second_third"), + Arrays.asList("one", "two", "three", "four", "five"), + "Separator not applied."); + } + + @Test + void testExpressionLanguage() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": {\n" + + " \"third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + + testRunner.setValidateExpressionUsage(true); + testRunner.setProperty(FlattenJson.SEPARATOR, "${separator.char}"); + final Map parsed = (Map) baseTest(testRunner, json, Collections.singletonMap("separator.char", "_"), 1); + assertEquals(parsed.get("first_second_third"), + Arrays.asList("one", "two", "three", "four", "five"), + "Separator not applied."); + } + + @Test + void testFlattenModeNormal() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": {\n" + + " \"third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL); + final Map parsed = (Map) baseTest(testRunner, json,5); + assertEquals("one", parsed.get("first.second.third[0]"), "Separator not applied."); + } + + @Test + void testFlattenModeKeepArrays() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": [\n" + + " {\n" + + " \"x\": 1,\n" + + " \"y\": 2,\n" + + " \"z\": [\n" + + " 3,\n" + + " 4,\n" + + " 5\n" + + " ]\n" + + " },\n" + + " [\n" + + " 6,\n" + + " 7,\n" + + " 8\n" + + " ],\n" + + " [\n" + + " [\n" + + " 9,\n" + + " 10\n" + + " ],\n" + + " 11,\n" + + " 12\n" + + " ]\n" + + " ],\n" + + " \"third\": {\n" + + " \"a\": \"b\",\n" + + " \"c\": \"d\",\n" + + " \"e\": \"f\"\n" + + " }\n" + + " }\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_ARRAYS); + final Map parsed = (Map) baseTest(testRunner, json, 4); + assertInstanceOf(List.class, parsed.get("first.second")); + assertEquals(Arrays.asList(6, 7, 8), ((List) parsed.get("first.second")).get(1)); + assertEquals("b", parsed.get("first.third.a"), "Separator not applied."); + } + + @Test + void testFlattenModeKeepPrimitiveArrays() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": [\n" + + " {\n" + + " \"x\": 1,\n" + + " \"y\": 2,\n" + + " \"z\": [\n" + + " 3,\n" + + " 4,\n" + + " 5\n" + + " ]\n" + + " },\n" + + " [\n" + + " 6,\n" + + " 7,\n" + + " 8\n" + + " ],\n" + + " [\n" + + " [\n" + + " 9,\n" + + " 10\n" + + " ],\n" + + " 11,\n" + + " 12\n" + + " ]\n" + + " ],\n" + + " \"third\": {\n" + + " \"a\": \"b\",\n" + + " \"c\": \"d\",\n" + + " \"e\": \"f\"\n" + + " }\n" + + " }\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_PRIMITIVE_ARRAYS); + final Map parsed = (Map) baseTest(testRunner, json, 10); + assertEquals(1, parsed.get("first.second[0].x"), "Separator not applied."); + assertEquals(Arrays.asList(3, 4, 5), parsed.get("first.second[0].z"), "Separator not applied."); + assertEquals(Arrays.asList(9, 10), parsed.get("first.second[2][0]"), "Separator not applied."); + assertEquals(11, parsed.get("first.second[2][1]"), "Separator not applied."); + assertEquals(12, parsed.get("first.second[2][2]"), "Separator not applied."); + assertEquals("d", parsed.get("first.third.c"), "Separator not applied."); + } + + @Test + void testFlattenModeDotNotation() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": {\n" + + " \"third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_DOT_NOTATION); + final Map parsed = (Map) baseTest(testRunner, json, 5); + assertEquals("one", parsed.get("first.second.third.0"), "Separator not applied."); + } + + @Test + void testFlattenSlash() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second\": {\n" + + " \"third\": [\n" + + " \"http://localhost/value1\",\n" + + " \"http://localhost/value2\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL); + final Map parsed = (Map) baseTest(testRunner, json, 2); + assertEquals("http://localhost/value1", parsed.get("first.second.third[0]"), "Separator not applied."); + } + + @Test + void testEscapeForJson() throws JsonProcessingException { + final String json = "{\n" + + " \"name\": \"Jos\\u00e9\"\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL); + final Map parsed = (Map) baseTest(testRunner, json, 1); + assertEquals("José", parsed.get("name"), "Separator not applied."); + } + + @Test + void testUnFlatten() throws JsonProcessingException { + final String json = "{\n" + + " \"test.msg\": \"Hello, world\",\n" + + " \"first.second.third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + "}"; + + testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN); + final Map parsed = (Map) baseTest(testRunner, json, 2); + assertEquals("Hello, world", ((Map) parsed.get("test")).get("msg")); + assertEquals(Arrays.asList("one", "two", "three", "four", "five"), + ((Map) ((Map) parsed.get("first")).get("second")).get("third")); + } + + @Test + void testUnFlattenWithDifferentSeparator() throws JsonProcessingException { + final String json = "{\n" + + " \"first_second_third\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + "}"; + + testRunner.setProperty(FlattenJson.SEPARATOR, "_"); + testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN); + final Map parsed = (Map) baseTest(testRunner, json, 1); + assertEquals(Arrays.asList("one", "two", "three", "four", "five"), + ((Map) ((Map) parsed.get("first")).get("second")).get("third")); + } + + @Test + void testUnFlattenForKeepArraysMode() throws JsonProcessingException { + final String json = "{\n" + + " \"a.b\": 1,\n" + + " \"a.c\": [\n" + + " false,\n" + + " {\n" + + " \"i.j\": [\n" + + " false,\n" + + " true,\n" + + " \"xy\"\n" + + " ]\n" + + " }\n" + + " ]\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_ARRAYS); + testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN); + final Map parsed = (Map) baseTest(testRunner, json, 1); + assertEquals(1, ((Map) parsed.get("a")).get("b")); + assertEquals(false, ((List) ((Map) parsed.get("a")).get("c")).get(0)); + assertEquals(Arrays.asList(false, true, "xy"), + ((Map) ((Map) ((List) ((Map) parsed.get("a")).get("c")).get(1)).get("i")).get("j")); + } + + @Test + void testUnFlattenForKeepPrimitiveArraysMode() throws JsonProcessingException { + final String json = "{\n" + + " \"first.second[0].x\": 1,\n" + + " \"first.second[0].y\": 2,\n" + + " \"first.second[0].z\": [\n" + + " 3,\n" + + " 4,\n" + + " 5\n" + + " ],\n" + + " \"first.second[1]\": [\n" + + " 6,\n" + + " 7,\n" + + " 8\n" + + " ],\n" + + " \"first.second[2][0]\": [\n" + + " 9,\n" + + " 10\n" + + " ],\n" + + " \"first.second[2][1]\": 11,\n" + + " \"first.second[2][2]\": 12,\n" + + " \"first.third.a\": \"b\",\n" + + " \"first.third.c\": \"d\",\n" + + " \"first.third.e\": \"f\"\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_KEEP_PRIMITIVE_ARRAYS); + testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN); + final Map parsed = (Map) baseTest(testRunner, json, 1); + assertEquals(1, ((Map) ((List) ((Map) parsed.get("first")).get("second")).get(0)).get("x")); + assertEquals(Arrays.asList(9, 10), ((List) ((List) ((Map) parsed.get("first")).get("second")).get(2)).get(0)); + assertEquals("d", ((Map) ((Map) parsed.get("first")).get("third")).get("c")); + } + + @Test + void testUnFlattenForDotNotationMode() throws JsonProcessingException { + final String json = "{\n" + + " \"first.second.third.0\": [\n" + + " \"one\",\n" + + " \"two\",\n" + + " \"three\",\n" + + " \"four\",\n" + + " \"five\"\n" + + " ]\n" + + "}"; + + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_DOT_NOTATION); + testRunner.setProperty(FlattenJson.RETURN_TYPE, FlattenJson.RETURN_TYPE_UNFLATTEN); + + final Map parsed = (Map) baseTest(testRunner, json, 1); + assertEquals(Arrays.asList("one", "two", "three", "four", "five"), + ((List) ((Map) ((Map) parsed.get("first")).get("second")).get("third")).get(0)); + } + + @Test + void testFlattenWithIgnoreReservedCharacters() throws JsonProcessingException { + final String json = "{\n" + + " \"first\": {\n" + + " \"second.third\": \"Hello\",\n" + + " \"fourth\": \"World\"\n" + + " }\n" + + "}"; + + testRunner.setProperty(FlattenJson.IGNORE_RESERVED_CHARACTERS, "true"); + + final Map parsed = (Map) baseTest(testRunner, json, 2); + assertEquals("Hello", parsed.get("first.second.third"), "Separator not applied."); + assertEquals("World", parsed.get("first.fourth"), "Separator not applied."); + } + + @Test + void testFlattenRecordSetWithIgnoreReservedCharacters() throws JsonProcessingException { + final String json = "[\n" + + " {\n" + + " \"first\": {\n" + + " \"second_third\": \"Hello\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"first\": {\n" + + " \"second_third\": \"World\"\n" + + " }\n" + + " }\n" + + "]"; + testRunner.setProperty(FlattenJson.SEPARATOR, "_"); + testRunner.setProperty(FlattenJson.IGNORE_RESERVED_CHARACTERS, "true"); + + final List expected = Arrays.asList("Hello", "World"); + + final List parsed = (List) baseTest(testRunner, json, 2); + for (int i = 0; i < parsed.size(); i++) { + assertEquals(expected.get(i), ((Map) parsed.get(i)).get("first_second_third"), "Missing values."); + } + } + + @Test + void testFlattenModeNormalWithIgnoreReservedCharacters() throws JsonProcessingException { + final String json = "[\n" + + " {\n" + + " \"first\": {\n" + + " \"second_third\": \"Hello\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"first\": {\n" + + " \"second_third\": \"World\"\n" + + " }\n" + + " }\n" + + "]"; + testRunner.setProperty(FlattenJson.SEPARATOR, "_"); + testRunner.setProperty(FlattenJson.IGNORE_RESERVED_CHARACTERS, "true"); + testRunner.setProperty(FlattenJson.FLATTEN_MODE, FlattenJson.FLATTEN_MODE_NORMAL); + + final Map parsed = (Map) baseTest(testRunner, json, 2); + assertEquals("Hello", parsed.get("[0]_first_second_third"), "Separator not applied."); + assertEquals("World", parsed.get("[1]_first_second_third"), "Separator not applied."); + } + + private Object baseTest(TestRunner testRunner, String json, int keyCount) throws JsonProcessingException { + return baseTest(testRunner, json, Collections.emptyMap(), keyCount); + } + + private Object baseTest(TestRunner testRunner, String json, Map attrs, int keyCount) throws JsonProcessingException { + testRunner.enqueue(json, attrs); + testRunner.run(1, true); + testRunner.assertTransferCount(FlattenJson.REL_FAILURE, 0); + testRunner.assertTransferCount(FlattenJson.REL_SUCCESS, 1); + + final List flowFiles = testRunner.getFlowFilesForRelationship(FlattenJson.REL_SUCCESS); + final byte[] content = testRunner.getContentAsByteArray(flowFiles.get(0)); + final String asJson = new String(content); + if (asJson.startsWith("[")) { + final List parsed; + parsed = mapper.readValue(asJson, List.class); + assertEquals(keyCount, parsed.size(), "Too many keys"); + return parsed; + } else { + final Map parsed; + parsed = mapper.readValue(asJson, Map.class); + assertEquals(keyCount, parsed.size(), "Too many keys"); + return parsed; + } + } +}