diff --git a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java index 7fff72d0e2..bab455a49e 100644 --- a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java +++ b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java @@ -41,6 +41,7 @@ import java.sql.Clob; import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.sql.Types; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; @@ -1607,6 +1608,55 @@ public class DataTypeUtils { } } + /** + * Converts the specified field data type into a java.sql.Types constant (INTEGER = 4, e.g.) + * + * @param dataType the DataType to be converted + * @return the SQL type corresponding to the specified RecordFieldType + */ + public static int getSQLTypeValue(final DataType dataType) { + if (dataType == null) { + return Types.NULL; + } + RecordFieldType fieldType = dataType.getFieldType(); + switch (fieldType) { + case BIGINT: + case LONG: + return Types.BIGINT; + case BOOLEAN: + return Types.BOOLEAN; + case BYTE: + return Types.TINYINT; + case CHAR: + return Types.CHAR; + case DATE: + return Types.DATE; + case DOUBLE: + return Types.DOUBLE; + case FLOAT: + return Types.FLOAT; + case INT: + return Types.INTEGER; + case SHORT: + return Types.SMALLINT; + case STRING: + return Types.VARCHAR; + case TIME: + return Types.TIME; + case TIMESTAMP: + return Types.TIMESTAMP; + case ARRAY: + return Types.ARRAY; + case MAP: + case RECORD: + return Types.STRUCT; + case CHOICE: + throw new IllegalTypeConversionException("Cannot convert CHOICE, type must be explicit"); + default: + throw new IllegalTypeConversionException("Cannot convert unknown type " + fieldType.name()); + } + } + public static boolean isScalarValue(final DataType dataType, final Object value) { final RecordFieldType fieldType = dataType.getFieldType(); diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java index 3046cbf4dc..8b4ad780d1 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java @@ -50,9 +50,11 @@ import org.apache.nifi.processor.util.pattern.RoutingResult; import org.apache.nifi.serialization.MalformedRecordException; import org.apache.nifi.serialization.RecordReader; import org.apache.nifi.serialization.RecordReaderFactory; +import org.apache.nifi.serialization.record.DataType; import org.apache.nifi.serialization.record.Record; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.util.DataTypeUtils; import java.io.IOException; import java.io.InputStream; @@ -687,26 +689,35 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { while ((currentRecord = recordParser.nextRecord()) != null) { Object[] values = currentRecord.getValues(); + List dataTypes = currentRecord.getSchema().getDataTypes(); if (values != null) { if (fieldIndexes != null) { for (int i = 0; i < fieldIndexes.size(); i++) { + final int currentFieldIndex = fieldIndexes.get(i); + final Object currentValue = values[currentFieldIndex]; + final DataType dataType = dataTypes.get(currentFieldIndex); + final int sqlType = DataTypeUtils.getSQLTypeValue(dataType); + // If DELETE type, insert the object twice because of the null check (see generateDelete for details) if (DELETE_TYPE.equalsIgnoreCase(statementType)) { - ps.setObject(i * 2 + 1, values[fieldIndexes.get(i)]); - ps.setObject(i * 2 + 2, values[fieldIndexes.get(i)]); + ps.setObject(i * 2 + 1, currentValue, sqlType); + ps.setObject(i * 2 + 2, currentValue, sqlType); } else { - ps.setObject(i + 1, values[fieldIndexes.get(i)]); + ps.setObject(i + 1, currentValue, sqlType); } } } else { // If there's no index map, assume all values are included and set them in order for (int i = 0; i < values.length; i++) { + final Object currentValue = values[i]; + final DataType dataType = dataTypes.get(i); + final int sqlType = DataTypeUtils.getSQLTypeValue(dataType); // If DELETE type, insert the object twice because of the null check (see generateDelete for details) if (DELETE_TYPE.equalsIgnoreCase(statementType)) { - ps.setObject(i * 2 + 1, values[i]); - ps.setObject(i * 2 + 2, values[i]); + ps.setObject(i * 2 + 1, currentValue, sqlType); + ps.setObject(i * 2 + 2, currentValue, sqlType); } else { - ps.setObject(i + 1, values[i]); + ps.setObject(i + 1, currentValue, sqlType); } } } diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy index c3bc6ed4d4..080502cc40 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/groovy/org/apache/nifi/processors/standard/TestPutDatabaseRecord.groovy @@ -49,6 +49,7 @@ import java.util.function.Supplier import static org.junit.Assert.assertEquals import static org.junit.Assert.assertFalse import static org.junit.Assert.assertNotNull +import static org.junit.Assert.assertNull import static org.junit.Assert.assertTrue import static org.junit.Assert.fail import static org.mockito.ArgumentMatchers.anyMap @@ -239,6 +240,7 @@ class TestPutDatabaseRecord { parser.addRecord(2, 'rec2', 102) parser.addRecord(3, 'rec3', 103) parser.addRecord(4, 'rec4', 104) + parser.addRecord(5, null, 105) runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) @@ -267,6 +269,10 @@ class TestPutDatabaseRecord { assertEquals(4, rs.getInt(1)) assertEquals('rec4', rs.getString(2)) assertEquals(104, rs.getInt(3)) + assertTrue(rs.next()) + assertEquals(5, rs.getInt(1)) + assertNull(rs.getString(2)) + assertEquals(105, rs.getInt(3)) assertFalse(rs.next()) stmt.close()