NIFI-8368: If decimal scale > precision, set precision = scale

This closes #4938

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Matthew Burgess 2021-03-25 17:10:41 -04:00 committed by exceptionfactory
parent bff3e94c01
commit a5dbf56114
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
4 changed files with 86 additions and 54 deletions

View File

@ -220,7 +220,7 @@ public class ResultSetRecordSet implements RecordSet, Closeable {
return RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType());
case Types.NUMERIC:
case Types.DECIMAL:
final int decimalPrecision;
int decimalPrecision;
final int decimalScale;
final int resultSetPrecision = rs.getMetaData().getPrecision(columnIndex);
final int resultSetScale = rs.getMetaData().getScale(columnIndex);
@ -239,6 +239,11 @@ public class ResultSetRecordSet implements RecordSet, Closeable {
// Default scale is used to preserve decimals in such case.
decimalScale = resultSetScale > 0 ? resultSetScale : defaultScale;
}
// Scale can be bigger than precision in some cases (Oracle, e.g.) If this is the case, assume precision refers to the number of
// decimal digits and thus precision = scale
if (decimalScale > decimalPrecision) {
decimalPrecision = decimalScale;
}
return RecordFieldType.DECIMAL.getDecimalDataType(decimalPrecision, decimalScale);
case Types.OTHER: {
// If we have no records to inspect, we can't really know its schema so we simply use the default data type.

View File

@ -61,6 +61,7 @@ public class ResultSetRecordSetTest {
private static final String COLUMN_NAME_BIG_DECIMAL_2 = "bigDecimal2";
private static final String COLUMN_NAME_BIG_DECIMAL_3 = "bigDecimal3";
private static final String COLUMN_NAME_BIG_DECIMAL_4 = "bigDecimal4";
private static final String COLUMN_NAME_BIG_DECIMAL_5 = "bigDecimal5";
private static final Object[][] COLUMNS = new Object[][] {
// column number; column label / name / schema field; column type; schema data type;
@ -81,6 +82,7 @@ public class ResultSetRecordSetTest {
{15, COLUMN_NAME_BIG_DECIMAL_2, Types.NUMERIC, RecordFieldType.DECIMAL.getDecimalDataType(4, 0)},
{16, COLUMN_NAME_BIG_DECIMAL_3, Types.JAVA_OBJECT, RecordFieldType.DECIMAL.getDecimalDataType(501, 1)},
{17, COLUMN_NAME_BIG_DECIMAL_4, Types.DECIMAL, RecordFieldType.DECIMAL.getDecimalDataType(10, 3)},
{18, COLUMN_NAME_BIG_DECIMAL_5, Types.DECIMAL, RecordFieldType.DECIMAL.getDecimalDataType(3, 10)},
};
@Mock
@ -189,6 +191,7 @@ public class ResultSetRecordSetTest {
final BigDecimal bigDecimal2Value = new BigDecimal("1234");
final BigDecimal bigDecimal3Value = new BigDecimal("1234567890.1");
final BigDecimal bigDecimal4Value = new BigDecimal("1234567.089");
final BigDecimal bigDecimal5Value = new BigDecimal("0.1234567");
when(resultSet.getObject(COLUMN_NAME_VARCHAR)).thenReturn(varcharValue);
when(resultSet.getObject(COLUMN_NAME_BIGINT)).thenReturn(bigintValue);
@ -207,6 +210,7 @@ public class ResultSetRecordSetTest {
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_2)).thenReturn(bigDecimal2Value);
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_3)).thenReturn(bigDecimal3Value);
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_4)).thenReturn(bigDecimal4Value);
when(resultSet.getObject(COLUMN_NAME_BIG_DECIMAL_5)).thenReturn(bigDecimal5Value);
// when
ResultSetRecordSet testSubject = new ResultSetRecordSet(resultSet, recordSchema);
@ -234,6 +238,7 @@ public class ResultSetRecordSetTest {
assertEquals(bigDecimal2Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_2));
assertEquals(bigDecimal3Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_3));
assertEquals(bigDecimal4Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_4));
assertEquals(bigDecimal5Value, record.getValue(COLUMN_NAME_BIG_DECIMAL_5));
}
private ResultSet givenResultSetForOther() throws SQLException {
@ -261,7 +266,16 @@ public class ResultSetRecordSetTest {
assertNotNull(resultSchema);
for (final Object[] column : COLUMNS) {
assertEquals("For column " + column[0] + " the converted type is not matching", column[3], resultSchema.getField((Integer) column[0] - 1).getDataType());
// The DECIMAL column with scale larger than precision will not match so verify that instead
DataType actualDataType = resultSchema.getField((Integer) column[0] - 1).getDataType();
DataType expectedDataType = (DataType) column[3];
if (expectedDataType.equals(RecordFieldType.DECIMAL.getDecimalDataType(3, 10))) {
DecimalDataType decimalDataType = (DecimalDataType) expectedDataType;
if (decimalDataType.getScale() > decimalDataType.getPrecision()) {
expectedDataType = RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(), decimalDataType.getScale());
}
}
assertEquals("For column " + column[0] + " the converted type is not matching", expectedDataType, actualDataType);
}
}
}

View File

@ -16,38 +16,26 @@
*/
package org.apache.nifi.util.db;
import static java.sql.Types.ARRAY;
import static java.sql.Types.BIGINT;
import static java.sql.Types.BINARY;
import static java.sql.Types.BIT;
import static java.sql.Types.BLOB;
import static java.sql.Types.BOOLEAN;
import static java.sql.Types.CHAR;
import static java.sql.Types.CLOB;
import static java.sql.Types.DATE;
import static java.sql.Types.DECIMAL;
import static java.sql.Types.DOUBLE;
import static java.sql.Types.FLOAT;
import static java.sql.Types.INTEGER;
import static java.sql.Types.LONGNVARCHAR;
import static java.sql.Types.LONGVARBINARY;
import static java.sql.Types.LONGVARCHAR;
import static java.sql.Types.NCHAR;
import static java.sql.Types.NCLOB;
import static java.sql.Types.NUMERIC;
import static java.sql.Types.NVARCHAR;
import static java.sql.Types.OTHER;
import static java.sql.Types.REAL;
import static java.sql.Types.ROWID;
import static java.sql.Types.SMALLINT;
import static java.sql.Types.SQLXML;
import static java.sql.Types.TIME;
import static java.sql.Types.TIMESTAMP;
import static java.sql.Types.TIMESTAMP_WITH_TIMEZONE;
import static java.sql.Types.TINYINT;
import static java.sql.Types.VARBINARY;
import static java.sql.Types.VARCHAR;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.SchemaBuilder.BaseTypeBuilder;
import org.apache.avro.SchemaBuilder.FieldAssembler;
import org.apache.avro.SchemaBuilder.NullDefault;
import org.apache.avro.SchemaBuilder.UnionAccumulator;
import org.apache.avro.UnresolvedUnionException;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.nifi.avro.AvroTypeUtil;
import org.apache.nifi.serialization.record.util.DataTypeUtils;
import javax.xml.bind.DatatypeConverter;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
@ -84,26 +72,37 @@ import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.SchemaBuilder.BaseTypeBuilder;
import org.apache.avro.SchemaBuilder.FieldAssembler;
import org.apache.avro.SchemaBuilder.NullDefault;
import org.apache.avro.SchemaBuilder.UnionAccumulator;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.UnresolvedUnionException;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.avro.AvroTypeUtil;
import org.apache.nifi.serialization.record.util.DataTypeUtils;
import javax.xml.bind.DatatypeConverter;
import static java.sql.Types.ARRAY;
import static java.sql.Types.BIGINT;
import static java.sql.Types.BINARY;
import static java.sql.Types.BIT;
import static java.sql.Types.BLOB;
import static java.sql.Types.BOOLEAN;
import static java.sql.Types.CHAR;
import static java.sql.Types.CLOB;
import static java.sql.Types.DATE;
import static java.sql.Types.DECIMAL;
import static java.sql.Types.DOUBLE;
import static java.sql.Types.FLOAT;
import static java.sql.Types.INTEGER;
import static java.sql.Types.LONGNVARCHAR;
import static java.sql.Types.LONGVARBINARY;
import static java.sql.Types.LONGVARCHAR;
import static java.sql.Types.NCHAR;
import static java.sql.Types.NCLOB;
import static java.sql.Types.NUMERIC;
import static java.sql.Types.NVARCHAR;
import static java.sql.Types.OTHER;
import static java.sql.Types.REAL;
import static java.sql.Types.ROWID;
import static java.sql.Types.SMALLINT;
import static java.sql.Types.SQLXML;
import static java.sql.Types.TIME;
import static java.sql.Types.TIMESTAMP;
import static java.sql.Types.TIMESTAMP_WITH_TIMEZONE;
import static java.sql.Types.TINYINT;
import static java.sql.Types.VARBINARY;
import static java.sql.Types.VARCHAR;
/**
* JDBC / SQL common functions.
@ -576,7 +575,7 @@ public class JdbcCommon {
case DECIMAL:
case NUMERIC:
if (options.useLogicalTypes) {
final int decimalPrecision;
int decimalPrecision;
final int decimalScale;
if (meta.getPrecision(i) > 0) {
// When database returns a certain precision, we can rely on that.
@ -593,6 +592,11 @@ public class JdbcCommon {
// Default scale is used to preserve decimals in such case.
decimalScale = meta.getScale(i) > 0 ? meta.getScale(i) : options.defaultScale;
}
// Scale can be bigger than precision in some cases (Oracle, e.g.) If this is the case, assume precision refers to the number of
// decimal digits and thus precision = scale
if (decimalScale > decimalPrecision) {
decimalPrecision = decimalScale;
}
final LogicalTypes.Decimal decimal = LogicalTypes.decimal(decimalPrecision, decimalScale);
addNullableField(builder, columnName,
u -> u.type(decimal.addToSchema(SchemaBuilder.builder().bytesType())));

View File

@ -428,6 +428,15 @@ public class TestJdbcCommon {
testConvertToAvroStreamForBigDecimal(bigDecimal, dbPrecision, 24, 24, expectedScale);
}
@Test
public void testConvertToAvroStreamForBigDecimalWithScaleLargerThanPrecision() throws SQLException, IOException {
final int expectedScale = 6; // Scale can be larger than precision in Oracle
final int dbPrecision = 5;
final BigDecimal bigDecimal = new BigDecimal("0.000123", new MathContext(dbPrecision));
// If db doesn't return a precision, default precision should be used.
testConvertToAvroStreamForBigDecimal(bigDecimal, dbPrecision, 10, expectedScale, expectedScale);
}
private void testConvertToAvroStreamForBigDecimal(BigDecimal bigDecimal, int dbPrecision, int defaultPrecision, int expectedPrecision, int expectedScale) throws SQLException, IOException {
final ResultSetMetaData metadata = mock(ResultSetMetaData.class);