NIFI-8368: If decimal scale > precision, set precision = scale

This closes #4938

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Matthew Burgess 2021-03-25 17:10:41 -04:00 committed by exceptionfactory
parent bff3e94c01
commit a5dbf56114
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
4 changed files with 86 additions and 54 deletions

View File

@ -220,7 +220,7 @@ public class ResultSetRecordSet implements RecordSet, Closeable {
return RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType()); return RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType());
case Types.NUMERIC: case Types.NUMERIC:
case Types.DECIMAL: case Types.DECIMAL:
final int decimalPrecision; int decimalPrecision;
final int decimalScale; final int decimalScale;
final int resultSetPrecision = rs.getMetaData().getPrecision(columnIndex); final int resultSetPrecision = rs.getMetaData().getPrecision(columnIndex);
final int resultSetScale = rs.getMetaData().getScale(columnIndex); final int resultSetScale = rs.getMetaData().getScale(columnIndex);
@ -239,6 +239,11 @@ public class ResultSetRecordSet implements RecordSet, Closeable {
// Default scale is used to preserve decimals in such case. // Default scale is used to preserve decimals in such case.
decimalScale = resultSetScale > 0 ? resultSetScale : defaultScale; decimalScale = resultSetScale > 0 ? resultSetScale : defaultScale;
} }
// Scale can be bigger than precision in some cases (e.g. Oracle). If this is the case, assume precision refers to the number of
// decimal digits, and thus set precision = scale.
if (decimalScale > decimalPrecision) {
decimalPrecision = decimalScale;
}
return RecordFieldType.DECIMAL.getDecimalDataType(decimalPrecision, decimalScale); return RecordFieldType.DECIMAL.getDecimalDataType(decimalPrecision, decimalScale);
case Types.OTHER: { case Types.OTHER: {
// If we have no records to inspect, we can't really know its schema so we simply use the default data type. // If we have no records to inspect, we can't really know its schema so we simply use the default data type.

View File

@ -61,6 +61,7 @@ public class ResultSetRecordSetTest {
private static final String COLUMN_NAME_BIG_DECIMAL_2 = "bigDecimal2"; private static final String COLUMN_NAME_BIG_DECIMAL_2 = "bigDecimal2";
private static final String COLUMN_NAME_BIG_DECIMAL_3 = "bigDecimal3"; private static final String COLUMN_NAME_BIG_DECIMAL_3 = "bigDecimal3";
private static final String COLUMN_NAME_BIG_DECIMAL_4 = "bigDecimal4"; private static final String COLUMN_NAME_BIG_DECIMAL_4 = "bigDecimal4";
private static final String COLUMN_NAME_BIG_DECIMAL_5 = "bigDecimal5";
private static final Object[][] COLUMNS = new Object[][] { private static final Object[][] COLUMNS = new Object[][] {
// column number; column label / name / schema field; column type; schema data type; // column number; column label / name / schema field; column type; schema data type;
@ -81,6 +82,7 @@ public class ResultSetRecordSetTest {
{15, COLUMN_NAME_BIG_DECIMAL_2, Types.NUMERIC, RecordFieldType.DECIMAL.getDecimalDataType(4, 0)}, {15, COLUMN_NAME_BIG_DECIMAL_2, Types.NUMERIC, RecordFieldType.DECIMAL.getDecimalDataType(4, 0)},
{16, COLUMN_NAME_BIG_DECIMAL_3, Types.JAVA_OBJECT, RecordFieldType.DECIMAL.getDecimalDataType(501, 1)}, {16, COLUMN_NAME_BIG_DECIMAL_3, Types.JAVA_OBJECT, RecordFieldType.DECIMAL.getDecimalDataType(501, 1)},
{17, COLUMN_NAME_BIG_DECIMAL_4, Types.DECIMAL, RecordFieldType.DECIMAL.getDecimalDataType(10, 3)}, {17, COLUMN_NAME_BIG_DECIMAL_4, Types.DECIMAL, RecordFieldType.DECIMAL.getDecimalDataType(10, 3)},
{18, COLUMN_NAME_BIG_DECIMAL_5, Types.DECIMAL, RecordFieldType.DECIMAL.getDecimalDataType(3, 10)},
}; };
@Mock @Mock
@ -189,6 +191,7 @@ public class ResultSetRecordSetTest {
final BigDecimal bigDecimal2Value = new BigDecimal("1234"); final BigDecimal bigDecimal2Value = new BigDecimal("1234");
final BigDecimal bigDecimal3Value = new BigDecimal("1234567890.1"); final BigDecimal bigDecimal3Value = new BigDecimal("1234567890.1");
final BigDecimal bigDecimal4Value = new BigDecimal("1234567.089"); final BigDecimal bigDecimal4Value = new BigDecimal("1234567.089");
final BigDecimal bigDecimal5Value = new BigDecimal("0.1234567");
when(resultSet.getObject(COLUMN_NAME_VARCHAR)).thenReturn(varcharValue); when(resultSet.getObject(COLUMN_NAME_VARCHAR)).thenReturn(varcharValue);
when(resultSet.getObject(COLUMN_NAME_BIGINT)).thenReturn(bigintValue); when(resultSet.getObject(COLUMN_NAME_BIGINT)).thenReturn(bigintValue);
@ -207,6 +210,7 @@ public class ResultSetRecordSetTest {
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_2)).thenReturn(bigDecimal2Value); when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_2)).thenReturn(bigDecimal2Value);
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_3)).thenReturn(bigDecimal3Value); when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_3)).thenReturn(bigDecimal3Value);
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_4)).thenReturn(bigDecimal4Value); when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_4)).thenReturn(bigDecimal4Value);
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_5)).thenReturn(bigDecimal5Value);
// when // when
ResultSetRecordSet testSubject = new ResultSetRecordSet(resultSet, recordSchema); ResultSetRecordSet testSubject = new ResultSetRecordSet(resultSet, recordSchema);
@ -234,6 +238,7 @@ public class ResultSetRecordSetTest {
assertEquals(bigDecimal2Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_2)); assertEquals(bigDecimal2Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_2));
assertEquals(bigDecimal3Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_3)); assertEquals(bigDecimal3Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_3));
assertEquals(bigDecimal4Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_4)); assertEquals(bigDecimal4Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_4));
assertEquals(bigDecimal5Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_5));
} }
private ResultSet givenResultSetForOther() throws SQLException { private ResultSet givenResultSetForOther() throws SQLException {
@ -261,7 +266,16 @@ public class ResultSetRecordSetTest {
assertNotNull(resultSchema); assertNotNull(resultSchema);
for (final Object[] column : COLUMNS) { for (final Object[] column : COLUMNS) {
assertEquals("For column " + column[0] + " the converted type is not matching", column[3], resultSchema.getField((Integer) column[0] - 1).getDataType()); // The DECIMAL column with scale larger than precision will not match so verify that instead
DataType actualDataType = resultSchema.getField((Integer) column[0] - 1).getDataType();
DataType expectedDataType = (DataType) column[3];
if (expectedDataType.equals(RecordFieldType.DECIMAL.getDecimalDataType(3, 10))) {
DecimalDataType decimalDataType = (DecimalDataType) expectedDataType;
if (decimalDataType.getScale() > decimalDataType.getPrecision()) {
expectedDataType = RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(), decimalDataType.getScale());
}
}
assertEquals("For column " + column[0] + " the converted type is not matching", expectedDataType, actualDataType);
} }
} }
} }

View File

@ -16,38 +16,26 @@
*/ */
package org.apache.nifi.util.db; package org.apache.nifi.util.db;
import static java.sql.Types.ARRAY; import org.apache.avro.LogicalTypes;
import static java.sql.Types.BIGINT; import org.apache.avro.Schema;
import static java.sql.Types.BINARY; import org.apache.avro.SchemaBuilder;
import static java.sql.Types.BIT; import org.apache.avro.SchemaBuilder.BaseTypeBuilder;
import static java.sql.Types.BLOB; import org.apache.avro.SchemaBuilder.FieldAssembler;
import static java.sql.Types.BOOLEAN; import org.apache.avro.SchemaBuilder.NullDefault;
import static java.sql.Types.CHAR; import org.apache.avro.SchemaBuilder.UnionAccumulator;
import static java.sql.Types.CLOB; import org.apache.avro.UnresolvedUnionException;
import static java.sql.Types.DATE; import org.apache.avro.file.CodecFactory;
import static java.sql.Types.DECIMAL; import org.apache.avro.file.DataFileWriter;
import static java.sql.Types.DOUBLE; import org.apache.avro.generic.GenericData;
import static java.sql.Types.FLOAT; import org.apache.avro.generic.GenericDatumWriter;
import static java.sql.Types.INTEGER; import org.apache.avro.generic.GenericRecord;
import static java.sql.Types.LONGNVARCHAR; import org.apache.avro.io.DatumWriter;
import static java.sql.Types.LONGVARBINARY; import org.apache.commons.lang3.StringUtils;
import static java.sql.Types.LONGVARCHAR; import org.apache.commons.lang3.exception.ExceptionUtils;
import static java.sql.Types.NCHAR; import org.apache.nifi.avro.AvroTypeUtil;
import static java.sql.Types.NCLOB; import org.apache.nifi.serialization.record.util.DataTypeUtils;
import static java.sql.Types.NUMERIC;
import static java.sql.Types.NVARCHAR;
import static java.sql.Types.OTHER;
import static java.sql.Types.REAL;
import static java.sql.Types.ROWID;
import static java.sql.Types.SMALLINT;
import static java.sql.Types.SQLXML;
import static java.sql.Types.TIME;
import static java.sql.Types.TIMESTAMP;
import static java.sql.Types.TIMESTAMP_WITH_TIMEZONE;
import static java.sql.Types.TINYINT;
import static java.sql.Types.VARBINARY;
import static java.sql.Types.VARCHAR;
import javax.xml.bind.DatatypeConverter;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -84,26 +72,37 @@ import java.util.function.Function;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.apache.avro.LogicalTypes; import static java.sql.Types.ARRAY;
import org.apache.avro.Schema; import static java.sql.Types.BIGINT;
import org.apache.avro.SchemaBuilder; import static java.sql.Types.BINARY;
import org.apache.avro.SchemaBuilder.BaseTypeBuilder; import static java.sql.Types.BIT;
import org.apache.avro.SchemaBuilder.FieldAssembler; import static java.sql.Types.BLOB;
import org.apache.avro.SchemaBuilder.NullDefault; import static java.sql.Types.BOOLEAN;
import org.apache.avro.SchemaBuilder.UnionAccumulator; import static java.sql.Types.CHAR;
import org.apache.avro.file.CodecFactory; import static java.sql.Types.CLOB;
import org.apache.avro.UnresolvedUnionException; import static java.sql.Types.DATE;
import org.apache.avro.file.DataFileWriter; import static java.sql.Types.DECIMAL;
import org.apache.avro.generic.GenericData; import static java.sql.Types.DOUBLE;
import org.apache.avro.generic.GenericDatumWriter; import static java.sql.Types.FLOAT;
import org.apache.avro.generic.GenericRecord; import static java.sql.Types.INTEGER;
import org.apache.avro.io.DatumWriter; import static java.sql.Types.LONGNVARCHAR;
import org.apache.commons.lang3.exception.ExceptionUtils; import static java.sql.Types.LONGVARBINARY;
import org.apache.commons.lang3.StringUtils; import static java.sql.Types.LONGVARCHAR;
import org.apache.nifi.avro.AvroTypeUtil; import static java.sql.Types.NCHAR;
import org.apache.nifi.serialization.record.util.DataTypeUtils; import static java.sql.Types.NCLOB;
import static java.sql.Types.NUMERIC;
import javax.xml.bind.DatatypeConverter; import static java.sql.Types.NVARCHAR;
import static java.sql.Types.OTHER;
import static java.sql.Types.REAL;
import static java.sql.Types.ROWID;
import static java.sql.Types.SMALLINT;
import static java.sql.Types.SQLXML;
import static java.sql.Types.TIME;
import static java.sql.Types.TIMESTAMP;
import static java.sql.Types.TIMESTAMP_WITH_TIMEZONE;
import static java.sql.Types.TINYINT;
import static java.sql.Types.VARBINARY;
import static java.sql.Types.VARCHAR;
/** /**
* JDBC / SQL common functions. * JDBC / SQL common functions.
@ -576,7 +575,7 @@ public class JdbcCommon {
case DECIMAL: case DECIMAL:
case NUMERIC: case NUMERIC:
if (options.useLogicalTypes) { if (options.useLogicalTypes) {
final int decimalPrecision; int decimalPrecision;
final int decimalScale; final int decimalScale;
if (meta.getPrecision(i) > 0) { if (meta.getPrecision(i) > 0) {
// When database returns a certain precision, we can rely on that. // When database returns a certain precision, we can rely on that.
@ -593,6 +592,11 @@ public class JdbcCommon {
// Default scale is used to preserve decimals in such case. // Default scale is used to preserve decimals in such case.
decimalScale = meta.getScale(i) > 0 ? meta.getScale(i) : options.defaultScale; decimalScale = meta.getScale(i) > 0 ? meta.getScale(i) : options.defaultScale;
} }
// Scale can be bigger than precision in some cases (e.g. Oracle). If this is the case, assume precision refers to the number of
// decimal digits, and thus set precision = scale.
if (decimalScale > decimalPrecision) {
decimalPrecision = decimalScale;
}
final LogicalTypes.Decimal decimal = LogicalTypes.decimal(decimalPrecision, decimalScale); final LogicalTypes.Decimal decimal = LogicalTypes.decimal(decimalPrecision, decimalScale);
addNullableField(builder, columnName, addNullableField(builder, columnName,
u -> u.type(decimal.addToSchema(SchemaBuilder.builder().bytesType()))); u -> u.type(decimal.addToSchema(SchemaBuilder.builder().bytesType())));

View File

@ -428,6 +428,15 @@ public class TestJdbcCommon {
testConvertToAvroStreamForBigDecimal(bigDecimal, dbPrecision, 24, 24, expectedScale); testConvertToAvroStreamForBigDecimal(bigDecimal, dbPrecision, 24, 24, expectedScale);
} }
/**
 * Verifies the Avro decimal schema produced when the database reports a scale that is
 * larger than its precision: the expected precision passed to the helper equals the scale.
 */
@Test
public void testConvertToAvroStreamForBigDecimalWithScaleLargerThanPrecision() throws SQLException, IOException {
final int expectedScale = 6; // Scale can be larger than precision in Oracle
final int dbPrecision = 5;
final BigDecimal bigDecimal = new BigDecimal("0.000123", new MathContext(dbPrecision));
// The db reports a precision (5) smaller than the value's scale (6), so the expected precision is the scale
// rather than the db precision or the default precision (10).
testConvertToAvroStreamForBigDecimal(bigDecimal, dbPrecision, 10, expectedScale, expectedScale);
}
private void testConvertToAvroStreamForBigDecimal(BigDecimal bigDecimal, int dbPrecision, int defaultPrecision, int expectedPrecision, int expectedScale) throws SQLException, IOException { private void testConvertToAvroStreamForBigDecimal(BigDecimal bigDecimal, int dbPrecision, int defaultPrecision, int expectedPrecision, int expectedScale) throws SQLException, IOException {
final ResultSetMetaData metadata = mock(ResultSetMetaData.class); final ResultSetMetaData metadata = mock(ResultSetMetaData.class);