NIFI-6409: Fixed issue with PutDatabaseRecord when driver doesn't support setObject() without type

Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com>

This closes #3585.
This commit is contained in:
Matthew Burgess 2019-07-17 12:22:39 -04:00 committed by Pierre Villard
parent 3fb4454375
commit d8388c1887
No known key found for this signature in database
GPG Key ID: BEE1599F0726E9CD
3 changed files with 73 additions and 6 deletions

View File

@ -41,6 +41,7 @@ import java.sql.Clob;
import java.sql.Date; import java.sql.Date;
import java.sql.Time; import java.sql.Time;
import java.sql.Timestamp; import java.sql.Timestamp;
import java.sql.Types;
import java.text.DateFormat; import java.text.DateFormat;
import java.text.ParseException; import java.text.ParseException;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
@ -1607,6 +1608,55 @@ public class DataTypeUtils {
} }
} }
/**
* Converts the specified field data type into a java.sql.Types constant (INTEGER = 4, e.g.)
*
* @param dataType the DataType to be converted
* @return the SQL type corresponding to the specified RecordFieldType
*/
public static int getSQLTypeValue(final DataType dataType) {
if (dataType == null) {
return Types.NULL;
}
RecordFieldType fieldType = dataType.getFieldType();
switch (fieldType) {
case BIGINT:
case LONG:
return Types.BIGINT;
case BOOLEAN:
return Types.BOOLEAN;
case BYTE:
return Types.TINYINT;
case CHAR:
return Types.CHAR;
case DATE:
return Types.DATE;
case DOUBLE:
return Types.DOUBLE;
case FLOAT:
return Types.FLOAT;
case INT:
return Types.INTEGER;
case SHORT:
return Types.SMALLINT;
case STRING:
return Types.VARCHAR;
case TIME:
return Types.TIME;
case TIMESTAMP:
return Types.TIMESTAMP;
case ARRAY:
return Types.ARRAY;
case MAP:
case RECORD:
return Types.STRUCT;
case CHOICE:
throw new IllegalTypeConversionException("Cannot convert CHOICE, type must be explicit");
default:
throw new IllegalTypeConversionException("Cannot convert unknown type " + fieldType.name());
}
}
public static boolean isScalarValue(final DataType dataType, final Object value) { public static boolean isScalarValue(final DataType dataType, final Object value) {
final RecordFieldType fieldType = dataType.getFieldType(); final RecordFieldType fieldType = dataType.getFieldType();

View File

@ -50,9 +50,11 @@ import org.apache.nifi.processor.util.pattern.RoutingResult;
import org.apache.nifi.serialization.MalformedRecordException; import org.apache.nifi.serialization.MalformedRecordException;
import org.apache.nifi.serialization.RecordReader; import org.apache.nifi.serialization.RecordReader;
import org.apache.nifi.serialization.RecordReaderFactory; import org.apache.nifi.serialization.RecordReaderFactory;
import org.apache.nifi.serialization.record.DataType;
import org.apache.nifi.serialization.record.Record; import org.apache.nifi.serialization.record.Record;
import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordField;
import org.apache.nifi.serialization.record.RecordSchema; import org.apache.nifi.serialization.record.RecordSchema;
import org.apache.nifi.serialization.record.util.DataTypeUtils;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -687,26 +689,35 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor {
while ((currentRecord = recordParser.nextRecord()) != null) { while ((currentRecord = recordParser.nextRecord()) != null) {
Object[] values = currentRecord.getValues(); Object[] values = currentRecord.getValues();
List<DataType> dataTypes = currentRecord.getSchema().getDataTypes();
if (values != null) { if (values != null) {
if (fieldIndexes != null) { if (fieldIndexes != null) {
for (int i = 0; i < fieldIndexes.size(); i++) { for (int i = 0; i < fieldIndexes.size(); i++) {
final int currentFieldIndex = fieldIndexes.get(i);
final Object currentValue = values[currentFieldIndex];
final DataType dataType = dataTypes.get(currentFieldIndex);
final int sqlType = DataTypeUtils.getSQLTypeValue(dataType);
// If DELETE type, insert the object twice because of the null check (see generateDelete for details) // If DELETE type, insert the object twice because of the null check (see generateDelete for details)
if (DELETE_TYPE.equalsIgnoreCase(statementType)) { if (DELETE_TYPE.equalsIgnoreCase(statementType)) {
ps.setObject(i * 2 + 1, values[fieldIndexes.get(i)]); ps.setObject(i * 2 + 1, currentValue, sqlType);
ps.setObject(i * 2 + 2, values[fieldIndexes.get(i)]); ps.setObject(i * 2 + 2, currentValue, sqlType);
} else { } else {
ps.setObject(i + 1, values[fieldIndexes.get(i)]); ps.setObject(i + 1, currentValue, sqlType);
} }
} }
} else { } else {
// If there's no index map, assume all values are included and set them in order // If there's no index map, assume all values are included and set them in order
for (int i = 0; i < values.length; i++) { for (int i = 0; i < values.length; i++) {
final Object currentValue = values[i];
final DataType dataType = dataTypes.get(i);
final int sqlType = DataTypeUtils.getSQLTypeValue(dataType);
// If DELETE type, insert the object twice because of the null check (see generateDelete for details) // If DELETE type, insert the object twice because of the null check (see generateDelete for details)
if (DELETE_TYPE.equalsIgnoreCase(statementType)) { if (DELETE_TYPE.equalsIgnoreCase(statementType)) {
ps.setObject(i * 2 + 1, values[i]); ps.setObject(i * 2 + 1, currentValue, sqlType);
ps.setObject(i * 2 + 2, values[i]); ps.setObject(i * 2 + 2, currentValue, sqlType);
} else { } else {
ps.setObject(i + 1, values[i]); ps.setObject(i + 1, currentValue, sqlType);
} }
} }
} }

View File

@ -49,6 +49,7 @@ import java.util.function.Supplier
import static org.junit.Assert.assertEquals import static org.junit.Assert.assertEquals
import static org.junit.Assert.assertFalse import static org.junit.Assert.assertFalse
import static org.junit.Assert.assertNotNull import static org.junit.Assert.assertNotNull
import static org.junit.Assert.assertNull
import static org.junit.Assert.assertTrue import static org.junit.Assert.assertTrue
import static org.junit.Assert.fail import static org.junit.Assert.fail
import static org.mockito.ArgumentMatchers.anyMap import static org.mockito.ArgumentMatchers.anyMap
@ -239,6 +240,7 @@ class TestPutDatabaseRecord {
parser.addRecord(2, 'rec2', 102) parser.addRecord(2, 'rec2', 102)
parser.addRecord(3, 'rec3', 103) parser.addRecord(3, 'rec3', 103)
parser.addRecord(4, 'rec4', 104) parser.addRecord(4, 'rec4', 104)
parser.addRecord(5, null, 105)
runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser') runner.setProperty(PutDatabaseRecord.RECORD_READER_FACTORY, 'parser')
runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE) runner.setProperty(PutDatabaseRecord.STATEMENT_TYPE, PutDatabaseRecord.INSERT_TYPE)
@ -267,6 +269,10 @@ class TestPutDatabaseRecord {
assertEquals(4, rs.getInt(1)) assertEquals(4, rs.getInt(1))
assertEquals('rec4', rs.getString(2)) assertEquals('rec4', rs.getString(2))
assertEquals(104, rs.getInt(3)) assertEquals(104, rs.getInt(3))
assertTrue(rs.next())
assertEquals(5, rs.getInt(1))
assertNull(rs.getString(2))
assertEquals(105, rs.getInt(3))
assertFalse(rs.next()) assertFalse(rs.next())
stmt.close() stmt.close()