diff --git a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java index eef9469de9..cd3cd4f148 100644 --- a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java +++ b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java @@ -1995,6 +1995,7 @@ public class DataTypeUtils { case Types.NVARCHAR: case Types.OTHER: case Types.SQLXML: + case Types.CLOB: return RecordFieldType.STRING.getDataType(); case Types.TIME: return RecordFieldType.TIME.getDataType(); @@ -2002,6 +2003,9 @@ public class DataTypeUtils { return RecordFieldType.TIMESTAMP.getDataType(); case Types.ARRAY: return RecordFieldType.ARRAY.getDataType(); + case Types.BINARY: + case Types.BLOB: + return RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType()); case Types.STRUCT: return RecordFieldType.RECORD.getDataType(); default: diff --git a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java index a4936094bc..4745abe713 100644 --- a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java +++ b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java @@ -362,6 +362,12 @@ public class TestDataTypeUtils { assertEquals(Types.NUMERIC, DataTypeUtils.getSQLTypeValue(RecordFieldType.DECIMAL.getDecimalDataType(30, 10))); } + @Test + public void testGetDataTypeFromSQLTypeValue() { + assertEquals(RecordFieldType.STRING.getDataType(), DataTypeUtils.getDataTypeFromSQLTypeValue(Types.CLOB)); + assertEquals(RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType()), DataTypeUtils.getDataTypeFromSQLTypeValue(Types.BLOB)); + } + @Test public void testChooseDataTypeWhenExpectedIsBigDecimal() { // GIVEN diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java index 9b2251fd6d..91407ea23e 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java @@ -59,9 +59,12 @@ import org.apache.nifi.serialization.record.RecordSchema; import org.apache.nifi.serialization.record.util.DataTypeUtils; import org.apache.nifi.serialization.record.util.IllegalTypeConversionException; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.sql.BatchUpdateException; +import java.sql.Clob; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.Date; @@ -721,10 +724,31 @@ public class PutDatabaseRecord extends AbstractProcessor { try { DataType targetDataType = DataTypeUtils.getDataTypeFromSQLTypeValue(sqlType); if (targetDataType != null) { - currentValue = DataTypeUtils.convertType( - currentValue, - targetDataType, - fieldName); + if (sqlType == Types.BLOB || sqlType == Types.BINARY) { + if (currentValue instanceof Object[]) { + // Convert Object[Byte] arrays to byte[] + Object[] src = (Object[]) currentValue; + if (src.length > 0) { + if (!(src[0] instanceof Byte)) { + throw new IllegalTypeConversionException("Cannot convert value " + currentValue + " to BLOB/BINARY"); + } + } + byte[] dest = new byte[src.length]; + for (int j = 0; j < src.length; j++) { + dest[j] = (Byte) src[j]; + } + currentValue = dest; + } else if (currentValue instanceof String) { + currentValue = ((String) currentValue).getBytes(StandardCharsets.UTF_8); + } else if (currentValue != null && !(currentValue instanceof byte[])) { + throw new IllegalTypeConversionException("Cannot convert value " + currentValue + " to BLOB/BINARY"); + } + } else { + currentValue = DataTypeUtils.convertType( + currentValue, + targetDataType, + fieldName); + } } } catch (IllegalTypeConversionException itce) { // If the field and column types don't match or the value can't otherwise be converted to the column datatype, @@ -740,15 +764,15 @@ public class PutDatabaseRecord extends AbstractProcessor { // If DELETE type, insert the object twice because of the null check (see generateDelete for details) if (DELETE_TYPE.equalsIgnoreCase(statementType)) { - ps.setObject(i * 2 + 1, currentValue, sqlType); - ps.setObject(i * 2 + 2, currentValue, sqlType); + setParameter(ps, i * 2 + 1, currentValue, fieldSqlType, sqlType); + setParameter(ps, i * 2 + 2, currentValue, fieldSqlType, sqlType); } else if (UPSERT_TYPE.equalsIgnoreCase(statementType)) { final int timesToAddObjects = databaseAdapter.getTimesToAddColumnObjectsForUpsert(); for (int j = 0; j < timesToAddObjects; j++) { - ps.setObject(i + (fieldIndexes.size() * j) + 1, currentValue, sqlType); + setParameter(ps, i + (fieldIndexes.size() * j) + 1, currentValue, fieldSqlType, sqlType); } } else { - ps.setObject(i + 1, currentValue, sqlType); + setParameter(ps, i + 1, currentValue, fieldSqlType, sqlType); } } @@ -776,6 +800,60 @@ public class PutDatabaseRecord extends AbstractProcessor { } } + private void setParameter(PreparedStatement ps, int index, Object value, int fieldSqlType, int sqlType) throws IOException { + if (sqlType == Types.BLOB) { + // Convert Byte[] or String (anything that has been converted to byte[]) into BLOB + if (fieldSqlType == Types.ARRAY || fieldSqlType == Types.VARCHAR) { + if (!(value instanceof byte[])) { + if (value == null) { + try { + ps.setNull(index, Types.BLOB); + return; + } catch (SQLException e) { + throw new IOException("Unable to setNull() on prepared statement" , e); + } + } else { + throw new IOException("Expected BLOB to be of type byte[] but is instead " + value.getClass().getName()); + } + } + byte[] byteArray = (byte[]) value; + try (InputStream inputStream = new ByteArrayInputStream(byteArray)) { + ps.setBlob(index, inputStream); + } catch (SQLException e) { + throw new IOException("Unable to parse binary data " + value, e.getCause()); + } + } else { + try (InputStream inputStream = new ByteArrayInputStream(value.toString().getBytes(StandardCharsets.UTF_8))) { + ps.setBlob(index, inputStream); + } catch (IOException | SQLException e) { + throw new IOException("Unable to parse binary data " + value, e.getCause()); + } + } + } else if (sqlType == Types.CLOB) { + if (value == null) { + try { + ps.setNull(index, Types.CLOB); + } catch (SQLException e) { + throw new IOException("Unable to setNull() on prepared statement", e); + } + } else { + try { + Clob clob = ps.getConnection().createClob(); + clob.setString(1, value.toString()); + ps.setClob(index, clob); + } catch (SQLException e) { + throw new IOException("Unable to parse data as CLOB/String " + value, e.getCause()); + } + } + } else { + try { + ps.setObject(index, value, sqlType); + } catch (SQLException e) { + throw new IOException("Unable to setObject() with value " + value + " at index " + index + " of type " + sqlType , e); + } + } + } + private List getDataRecords(final Record outerRecord) { if (dataRecordPath == null) { return Collections.singletonList(outerRecord); diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy index e3c8e4c5b0..d6eebec12c 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy @@ -38,6 +38,8 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import java.sql.Blob +import java.sql.Clob import java.sql.Connection import java.sql.Date import java.sql.DriverManager @@ -150,7 +152,6 @@ class TestPutDatabaseRecord { false, ['id'] as Set, '' - ] as PutDatabaseRecord.TableSchema runner.setProperty(PutDatabaseRecord.TRANSLATE_FIELD_NAMES, 'false') @@ -1451,4 +1452,176 @@ class TestPutDatabaseRecord { stmt.close() conn.close() } + + @Test + void testInsertWithBlobClob() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" + + recreateTable(createTableWithBlob) + final MockRecordParser parser = new MockRecordParser() + runner.addControllerService("parser", parser) + runner.enableControllerService(parser) + + byte[] bytes = "BLOB".getBytes() + Byte[] blobRecordValue = new Byte[bytes.length] + (0 .. (bytes.length-1)).each { i -> blobRecordValue[i] = bytes[i].longValue() } + + parser.addSchemaField("id", RecordFieldType.INT) + parser.addSchemaField("name", RecordFieldType.STRING) + parser.addSchemaField("code", RecordFieldType.INT) + parser.addSchemaField("content", RecordFieldType.ARRAY) + + parser.addRecord(1, 'rec1', 101, blobRecordValue) + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) + runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') + + runner.enqueue(new byte[0]) + runner.run() + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) + final Connection conn = dbcp.getConnection() + final Statement stmt = conn.createStatement() + final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') + assertTrue(rs.next()) + assertEquals(1, rs.getInt(1)) + Clob clob = rs.getClob(2) + assertNotNull(clob) + char[] clobText = new char[5] + int numBytes = clob.characterStream.read(clobText) + assertEquals(4, numBytes) + // Ignore last character, it's meant to ensure that only 4 bytes were read even though the buffer is 5 bytes + assertEquals('rec1', new String(clobText).substring(0,4)) + Blob blob = rs.getBlob(3) + assertEquals("BLOB", new String(blob.getBytes(1, blob.length() as int))) + assertEquals(101, rs.getInt(4)) + + stmt.close() + conn.close() + } + + @Test + void testInsertWithBlobClobObjectArraySource() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" + + recreateTable(createTableWithBlob) + final MockRecordParser parser = new MockRecordParser() + runner.addControllerService("parser", parser) + runner.enableControllerService(parser) + + byte[] bytes = "BLOB".getBytes() + Object[] blobRecordValue = new Object[bytes.length] + (0 .. (bytes.length-1)).each { i -> blobRecordValue[i] = bytes[i] } + + parser.addSchemaField("id", RecordFieldType.INT) + parser.addSchemaField("name", RecordFieldType.STRING) + parser.addSchemaField("code", RecordFieldType.INT) + parser.addSchemaField("content", RecordFieldType.ARRAY) + + parser.addRecord(1, 'rec1', 101, blobRecordValue) + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) + runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') + + runner.enqueue(new byte[0]) + runner.run() + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) + final Connection conn = dbcp.getConnection() + final Statement stmt = conn.createStatement() + final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') + assertTrue(rs.next()) + assertEquals(1, rs.getInt(1)) + Clob clob = rs.getClob(2) + assertNotNull(clob) + char[] clobText = new char[5] + int numBytes = clob.characterStream.read(clobText) + assertEquals(4, numBytes) + // Ignore last character, it's meant to ensure that only 4 bytes were read even though the buffer is 5 bytes + assertEquals('rec1', new String(clobText).substring(0,4)) + Blob blob = rs.getBlob(3) + assertEquals("BLOB", new String(blob.getBytes(1, blob.length() as int))) + assertEquals(101, rs.getInt(4)) + + stmt.close() + conn.close() + } + + @Test + void testInsertWithBlobStringSource() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" + + recreateTable(createTableWithBlob) + final MockRecordParser parser = new MockRecordParser() + runner.addControllerService("parser", parser) + runner.enableControllerService(parser) + + parser.addSchemaField("id", RecordFieldType.INT) + parser.addSchemaField("name", RecordFieldType.STRING) + parser.addSchemaField("code", RecordFieldType.INT) + parser.addSchemaField("content", RecordFieldType.STRING) + + parser.addRecord(1, 'rec1', 101, 'BLOB') + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) + runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') + + runner.enqueue(new byte[0]) + runner.run() + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 1) + final Connection conn = dbcp.getConnection() + final Statement stmt = conn.createStatement() + final ResultSet rs = stmt.executeQuery('SELECT * FROM PERSONS') + assertTrue(rs.next()) + assertEquals(1, rs.getInt(1)) + Clob clob = rs.getClob(2) + assertNotNull(clob) + char[] clobText = new char[5] + int numBytes = clob.characterStream.read(clobText) + assertEquals(4, numBytes) + // Ignore last character, it's meant to ensure that only 4 bytes were read even though the buffer is 5 bytes + assertEquals('rec1', new String(clobText).substring(0,4)) + Blob blob = rs.getBlob(3) + assertEquals("BLOB", new String(blob.getBytes(1, blob.length() as int))) + assertEquals(101, rs.getInt(4)) + + stmt.close() + conn.close() + } + + @Test + void testInsertWithBlobIntegerArraySource() throws Exception { + String createTableWithBlob = "CREATE TABLE PERSONS (id integer primary key, name clob," + + "content blob, code integer CONSTRAINT CODE_RANGE CHECK (code >= 0 AND code < 1000))" + + recreateTable(createTableWithBlob) + final MockRecordParser parser = new MockRecordParser() + runner.addControllerService("parser", parser) + runner.enableControllerService(parser) + + parser.addSchemaField("id", RecordFieldType.INT) + parser.addSchemaField("name", RecordFieldType.STRING) + parser.addSchemaField("code", RecordFieldType.INT) + parser.addSchemaField("content", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType()).getFieldType()) + + parser.addRecord(1, 'rec1', 101, [1,2,3] as Integer[]) + + runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') + runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) + runner.setProperty(PutDatabaseRecord.TABLE_NAME, 'PERSONS') + + runner.enqueue(new byte[0]) + runner.run() + + runner.assertTransferCount(PutDatabaseRecord.REL_SUCCESS, 0) + runner.assertTransferCount(PutDatabaseRecord.REL_RETRY, 0) + runner.assertTransferCount(PutDatabaseRecord.REL_FAILURE, 1) + } }