NIFI-10508: When inferring data types for values, allow float and double to encapsulate byte/short/int/long values

Signed-off-by: Matthew Burgess <mattyb149@apache.org>

This closes #6421
This commit is contained in:
Mark Payne 2022-09-14 16:03:29 -04:00 committed by Matthew Burgess
parent 860d550435
commit ec09c56e93
No known key found for this signature in database
GPG Key ID: 05D3DEB8126DAD24
8 changed files with 260 additions and 129 deletions

View File

@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.serialization.record.util;
import org.apache.nifi.serialization.record.DataType;
import org.apache.nifi.serialization.record.RecordFieldType;
import org.apache.nifi.serialization.record.type.ChoiceDataType;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
* A container class, to which multiple DataTypes can be added, such that
* adding any two types where one is more narrow than the other will result
* in combining the two types into the wider type.
*/
public class DataTypeSet {
private final List<DataType> types = new ArrayList<>();
/**
* Adds the given data type to the set of types to consider
* @param dataType the data type to add
*/
public void add(final DataType dataType) {
if (dataType == null) {
return;
}
if (dataType.getFieldType() == RecordFieldType.CHOICE) {
final ChoiceDataType choiceDataType = (ChoiceDataType) dataType;
choiceDataType.getPossibleSubTypes().forEach(this::add);
return;
}
if (types.contains(dataType)) {
return;
}
DataType toRemove = null;
DataType toAdd = null;
for (final DataType currentType : types) {
final Optional<DataType> widerType = DataTypeUtils.getWiderType(currentType, dataType);
if (widerType.isPresent()) {
toRemove = currentType;
toAdd = widerType.get();
}
}
if (toRemove != null) {
types.remove(toRemove);
}
types.add( toAdd == null ? dataType : toAdd );
}
/**
* @return the combined types
*/
public List<DataType> getTypes() {
return new ArrayList<>(types);
}
}

View File

@ -64,7 +64,6 @@ import java.util.EnumMap;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -590,34 +589,8 @@ public class DataTypeUtils {
m.forEach((k, v) -> map.put(k == null ? null : k.toString(), v)); m.forEach((k, v) -> map.put(k == null ? null : k.toString(), v));
} }
return inferRecordDataType(map); return inferRecordDataType(map);
// // Check if all types are the same.
// if (map.isEmpty()) {
// return RecordFieldType.MAP.getMapDataType(RecordFieldType.STRING.getDataType());
// }
//
// Object valueFromMap = null;
// Class<?> valueClass = null;
// for (final Object val : map.values()) {
// if (val == null) {
// continue;
// }
//
// valueFromMap = val;
// final Class<?> currentValClass = val.getClass();
// if (valueClass == null) {
// valueClass = currentValClass;
// } else {
// // If we have two elements that are of different types, then we cannot have a Map. Must be a Record.
// if (valueClass != currentValClass) {
// return inferRecordDataType(map);
// }
// }
// }
//
// // All values appear to be of the same type, so assume that it's a map.
// final DataType elementDataType = inferDataType(valueFromMap, RecordFieldType.STRING.getDataType());
// return RecordFieldType.MAP.getMapDataType(elementDataType);
} }
if (value.getClass().isArray()) { if (value.getClass().isArray()) {
DataType mergedDataType = null; DataType mergedDataType = null;
@ -633,8 +606,9 @@ public class DataTypeUtils {
return RecordFieldType.ARRAY.getArrayDataType(mergedDataType); return RecordFieldType.ARRAY.getArrayDataType(mergedDataType);
} }
if (value instanceof Iterable) { if (value instanceof Iterable) {
final Iterable iterable = (Iterable<?>) value; final Iterable<?> iterable = (Iterable<?>) value;
DataType mergedDataType = null; DataType mergedDataType = null;
for (final Object arrayValue : iterable) { for (final Object arrayValue : iterable) {
@ -1998,33 +1972,34 @@ public class DataTypeUtils {
return widerType.get(); return widerType.get();
} }
final Set<DataType> possibleTypes = new LinkedHashSet<>(); final DataTypeSet dataTypeSet = new DataTypeSet();
if (thisDataType.getFieldType() == RecordFieldType.CHOICE) { dataTypeSet.add(thisDataType);
possibleTypes.addAll(((ChoiceDataType) thisDataType).getPossibleSubTypes()); dataTypeSet.add(otherDataType);
} else {
possibleTypes.add(thisDataType);
}
if (otherDataType.getFieldType() == RecordFieldType.CHOICE) { final List<DataType> possibleChildTypes = dataTypeSet.getTypes();
possibleTypes.addAll(((ChoiceDataType) otherDataType).getPossibleSubTypes()); possibleChildTypes.sort(Comparator.comparing(DataType::getFieldType));
} else {
possibleTypes.add(otherDataType);
}
ArrayList<DataType> possibleChildTypes = new ArrayList<>(possibleTypes);
Collections.sort(possibleChildTypes, Comparator.comparing(DataType::getFieldType));
return RecordFieldType.CHOICE.getChoiceDataType(possibleChildTypes); return RecordFieldType.CHOICE.getChoiceDataType(possibleChildTypes);
} }
} }
public static Optional<DataType> getWiderType(final DataType thisDataType, final DataType otherDataType) { public static Optional<DataType> getWiderType(final DataType thisDataType, final DataType otherDataType) {
if (thisDataType == null) {
return Optional.ofNullable(otherDataType);
}
if (otherDataType == null) {
return Optional.of(thisDataType);
}
final RecordFieldType thisFieldType = thisDataType.getFieldType(); final RecordFieldType thisFieldType = thisDataType.getFieldType();
final RecordFieldType otherFieldType = otherDataType.getFieldType(); final RecordFieldType otherFieldType = otherDataType.getFieldType();
final int thisIntTypeValue = getIntegerTypeValue(thisFieldType); final int thisIntTypeValue = getIntegerTypeValue(thisFieldType);
final int otherIntTypeValue = getIntegerTypeValue(otherFieldType); final int otherIntTypeValue = getIntegerTypeValue(otherFieldType);
if (thisIntTypeValue > -1 && otherIntTypeValue > -1) { final boolean thisIsInt = thisIntTypeValue > -1;
final boolean otherIsInt = otherIntTypeValue > -1;
if (thisIsInt && otherIsInt) {
if (thisIntTypeValue > otherIntTypeValue) { if (thisIntTypeValue > otherIntTypeValue) {
return Optional.of(thisDataType); return Optional.of(thisDataType);
} }
@ -2032,25 +2007,37 @@ public class DataTypeUtils {
return Optional.of(otherDataType); return Optional.of(otherDataType);
} }
final boolean otherIsDecimal = isDecimalType(otherFieldType);
switch (thisFieldType) { switch (thisFieldType) {
case BYTE:
case SHORT:
case INT:
case LONG:
if (otherIsDecimal) {
return Optional.of(otherDataType);
}
break;
case FLOAT: case FLOAT:
if (otherFieldType == RecordFieldType.DOUBLE) { if (otherFieldType == RecordFieldType.DOUBLE || otherFieldType == RecordFieldType.DECIMAL) {
return Optional.of(otherDataType);
} else if (otherFieldType == RecordFieldType.DECIMAL) {
return Optional.of(otherDataType); return Optional.of(otherDataType);
} }
if (otherFieldType == RecordFieldType.BYTE || otherFieldType == RecordFieldType.SHORT || otherFieldType == RecordFieldType.INT || otherFieldType == RecordFieldType.LONG) {
return Optional.of(thisDataType);
}
break; break;
case DOUBLE: case DOUBLE:
if (otherFieldType == RecordFieldType.FLOAT) { if (otherFieldType == RecordFieldType.DECIMAL) {
return Optional.of(thisDataType);
} else if (otherFieldType == RecordFieldType.DECIMAL) {
return Optional.of(otherDataType); return Optional.of(otherDataType);
} }
if (otherFieldType == RecordFieldType.BYTE || otherFieldType == RecordFieldType.SHORT || otherFieldType == RecordFieldType.INT || otherFieldType == RecordFieldType.LONG
|| otherFieldType == RecordFieldType.FLOAT) {
return Optional.of(thisDataType);
}
break; break;
case DECIMAL: case DECIMAL:
if (otherFieldType == RecordFieldType.DOUBLE) { if (otherFieldType == RecordFieldType.DOUBLE || otherFieldType == RecordFieldType.FLOAT || otherIsInt) {
return Optional.of(thisDataType);
} else if (otherFieldType == RecordFieldType.FLOAT) {
return Optional.of(thisDataType); return Optional.of(thisDataType);
} else if (otherFieldType == RecordFieldType.DECIMAL) { } else if (otherFieldType == RecordFieldType.DECIMAL) {
final DecimalDataType thisDecimalDataType = (DecimalDataType) thisDataType; final DecimalDataType thisDecimalDataType = (DecimalDataType) thisDataType;
@ -2062,12 +2049,13 @@ public class DataTypeUtils {
} }
break; break;
case CHAR: case CHAR:
case UUID:
if (otherFieldType == RecordFieldType.STRING) { if (otherFieldType == RecordFieldType.STRING) {
return Optional.of(otherDataType); return Optional.of(otherDataType);
} }
break; break;
case STRING: case STRING:
if (otherFieldType == RecordFieldType.CHAR) { if (otherFieldType == RecordFieldType.CHAR || otherFieldType == RecordFieldType.UUID) {
return Optional.of(thisDataType); return Optional.of(thisDataType);
} }
break; break;
@ -2076,6 +2064,17 @@ public class DataTypeUtils {
return Optional.empty(); return Optional.empty();
} }
private static boolean isDecimalType(final RecordFieldType fieldType) {
switch (fieldType) {
case FLOAT:
case DOUBLE:
case DECIMAL:
return true;
default:
return false;
}
}
private static int getIntegerTypeValue(final RecordFieldType fieldType) { private static int getIntegerTypeValue(final RecordFieldType fieldType) {
switch (fieldType) { switch (fieldType) {
case BIGINT: case BIGINT:

View File

@ -94,6 +94,12 @@ public class TestDataTypeUtils {
assertEquals(ts.getTime(), sDate.getTime(), "Times didn't match"); assertEquals(ts.getTime(), sDate.getTime(), "Times didn't match");
} }
@Test
public void testIntDoubleWiderType() {
assertEquals(Optional.of(RecordFieldType.DOUBLE.getDataType()), DataTypeUtils.getWiderType(RecordFieldType.INT.getDataType(), RecordFieldType.DOUBLE.getDataType()));
assertEquals(Optional.of(RecordFieldType.DOUBLE.getDataType()), DataTypeUtils.getWiderType(RecordFieldType.DOUBLE.getDataType(), RecordFieldType.INT.getDataType()));
}
/* /*
* This was a bug in NiFi 1.8 where converting from a Timestamp to a Date with the record path API * This was a bug in NiFi 1.8 where converting from a Timestamp to a Date with the record path API
* would throw an exception. * would throw an exception.

View File

@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.serialization.record.util;
import org.apache.nifi.serialization.record.RecordFieldType;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestDataTypeSet {
@Test
public void testCombineNarrowThenWider() {
final DataTypeSet set = new DataTypeSet();
set.add(RecordFieldType.INT.getDataType());
set.add(RecordFieldType.DOUBLE.getDataType());
assertEquals(Collections.singletonList(RecordFieldType.DOUBLE.getDataType()), set.getTypes());
}
@Test
public void testAddIncompatible() {
final DataTypeSet set = new DataTypeSet();
set.add(RecordFieldType.INT.getDataType());
set.add(RecordFieldType.BOOLEAN.getDataType());
assertEquals(Arrays.asList(RecordFieldType.INT.getDataType(), RecordFieldType.BOOLEAN.getDataType()), set.getTypes());
}
@Test
public void addSingleType() {
final DataTypeSet set = new DataTypeSet();
set.add(RecordFieldType.INT.getDataType());
assertEquals(Collections.singletonList(RecordFieldType.INT.getDataType()), set.getTypes());
}
@Test
public void testCombineWiderThenNarrow() {
final DataTypeSet set = new DataTypeSet();
set.add(RecordFieldType.DOUBLE.getDataType());
set.add(RecordFieldType.INT.getDataType());
assertEquals(Collections.singletonList(RecordFieldType.DOUBLE.getDataType()), set.getTypes());
}
@Test
public void testAddNothing() {
final DataTypeSet set = new DataTypeSet();
assertEquals(Collections.emptyList(), set.getTypes());
}
}

View File

@ -193,11 +193,6 @@ public class TestStandardSchemaValidator {
whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Integer.MAX_VALUE, RecordFieldType.DECIMAL); whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Integer.MAX_VALUE, RecordFieldType.DECIMAL);
} }
@Test
public void testIntegerOutsideRangeIsConsideredAsInvalid() {
whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_FLOAT.intValue() + 1, RecordFieldType.FLOAT);
// Double handles integer completely
}
@Test @Test
public void testLongWithinRangeIsConsideredToBeValidFloatingPoint() { public void testLongWithinRangeIsConsideredToBeValidFloatingPoint() {
@ -206,12 +201,6 @@ public class TestStandardSchemaValidator {
whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Long.MAX_VALUE, RecordFieldType.DECIMAL); whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Long.MAX_VALUE, RecordFieldType.DECIMAL);
} }
@Test
public void testLongOutsideRangeIsConsideredAsInvalid() {
whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_FLOAT + 1, RecordFieldType.FLOAT);
whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_DOUBLE + 1, RecordFieldType.DOUBLE);
}
@Test @Test
public void testBigintWithinRangeIsConsideredToBeValidFloatingPoint() { public void testBigintWithinRangeIsConsideredToBeValidFloatingPoint() {
whenValueIsAcceptedAsDataTypeThenConsideredAsValid(BigInteger.valueOf(5L), RecordFieldType.FLOAT); whenValueIsAcceptedAsDataTypeThenConsideredAsValid(BigInteger.valueOf(5L), RecordFieldType.FLOAT);

View File

@ -120,17 +120,17 @@ public class FlowFileEnumerator implements Enumerator<Object> {
return filtered; return filtered;
} }
private Object cast(Object o) { private Object cast(final Object toCast) {
if (o == null) { if (toCast == null) {
return null; return null;
} else if (o.getClass().isArray()) { } else if (toCast.getClass().isArray()) {
List<Object> l = new ArrayList(Array.getLength(o)); final List<Object> list = new ArrayList<>(Array.getLength(toCast));
for (int i = 0; i < Array.getLength(o); i++) { for (int i = 0; i < Array.getLength(toCast); i++) {
l.add(Array.get(o, i)); list.add(Array.get(toCast, i));
} }
return l; return list;
} else { } else {
return o; return toCast;
} }
} }

View File

@ -34,7 +34,7 @@ public class FieldTypeInference {
// unique value for the data type, and so this paradigm allows us to avoid the cost of creating // unique value for the data type, and so this paradigm allows us to avoid the cost of creating
// and using the HashSet. // and using the HashSet.
private DataType singleDataType = null; private DataType singleDataType = null;
private Set<DataType> possibleDataTypes = new HashSet<>(); private final Set<DataType> possibleDataTypes = new HashSet<>();
public void addPossibleDataType(final DataType dataType) { public void addPossibleDataType(final DataType dataType) {
if (dataType == null) { if (dataType == null) {
@ -73,17 +73,17 @@ public class FieldTypeInference {
possibleDataTypes.add(singleDataType); possibleDataTypes.add(singleDataType);
} }
for (DataType possibleDataType : possibleDataTypes) { for (final DataType possibleDataType : possibleDataTypes) {
RecordFieldType possibleFieldType = possibleDataType.getFieldType(); final RecordFieldType possibleFieldType = possibleDataType.getFieldType();
if (!possibleFieldType.equals(RecordFieldType.STRING) && possibleFieldType.isWiderThan(additionalFieldType)) { if (!possibleFieldType.equals(RecordFieldType.STRING) && possibleFieldType.isWiderThan(additionalFieldType)) {
return; return;
} }
} }
Iterator<DataType> possibleDataTypeIterator = possibleDataTypes.iterator(); final Iterator<DataType> possibleDataTypeIterator = possibleDataTypes.iterator();
while (possibleDataTypeIterator.hasNext()) { while (possibleDataTypeIterator.hasNext()) {
DataType possibleDataType = possibleDataTypeIterator.next(); final DataType possibleDataType = possibleDataTypeIterator.next();
RecordFieldType possibleFieldType = possibleDataType.getFieldType(); final RecordFieldType possibleFieldType = possibleDataType.getFieldType();
if (!additionalFieldType.equals(RecordFieldType.STRING) && additionalFieldType.isWiderThan(possibleFieldType)) { if (!additionalFieldType.equals(RecordFieldType.STRING) && additionalFieldType.isWiderThan(possibleFieldType)) {
possibleDataTypeIterator.remove(); possibleDataTypeIterator.remove();

View File

@ -20,11 +20,13 @@ import org.apache.nifi.serialization.SimpleRecordSchema;
import org.apache.nifi.serialization.record.DataType; import org.apache.nifi.serialization.record.DataType;
import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordField;
import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordFieldType;
import org.apache.nifi.serialization.record.RecordSchema;
import org.apache.nifi.serialization.record.type.ChoiceDataType; import org.apache.nifi.serialization.record.type.ChoiceDataType;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -40,6 +42,25 @@ public class TestFieldTypeInference {
testSubject = new FieldTypeInference(); testSubject = new FieldTypeInference();
} }
@Test
public void testIntegerCombinedWithDouble() {
final FieldTypeInference inference = new FieldTypeInference();
inference.addPossibleDataType(RecordFieldType.INT.getDataType());
inference.addPossibleDataType(RecordFieldType.DOUBLE.getDataType());
assertEquals(RecordFieldType.DOUBLE.getDataType(), inference.toDataType());
}
@Test
public void testIntegerCombinedWithFloat() {
final FieldTypeInference inference = new FieldTypeInference();
inference.addPossibleDataType(RecordFieldType.INT.getDataType());
inference.addPossibleDataType(RecordFieldType.FLOAT.getDataType());
assertEquals(RecordFieldType.FLOAT.getDataType(), inference.toDataType());
}
@Test @Test
public void testToDataTypeWith_SHORT_INT_LONG_shouldReturn_LONG() { public void testToDataTypeWith_SHORT_INT_LONG_shouldReturn_LONG() {
// GIVEN // GIVEN
@ -58,20 +79,13 @@ public class TestFieldTypeInference {
@Test @Test
public void testToDataTypeWith_INT_FLOAT_ShouldReturn_INT_FLOAT() { public void testToDataTypeWith_INT_FLOAT_ShouldReturn_INT_FLOAT() {
// GIVEN final List<DataType> dataTypes = Arrays.asList(
List<DataType> dataTypes = Arrays.asList(
RecordFieldType.INT.getDataType(), RecordFieldType.INT.getDataType(),
RecordFieldType.FLOAT.getDataType() RecordFieldType.FLOAT.getDataType()
); );
Set<DataType> expected = new HashSet<>(Arrays.asList( final DataType expected = RecordFieldType.FLOAT.getDataType();
RecordFieldType.INT.getDataType(), runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, dataTypes, expected);
RecordFieldType.FLOAT.getDataType()
));
// WHEN
// THEN
runWithAllPermutations(this::testToDataTypeShouldReturnChoice, dataTypes, expected);
} }
@Test @Test
@ -94,52 +108,39 @@ public class TestFieldTypeInference {
} }
@Test @Test
public void testToDataTypeWith_INT_FLOAT_STRING_shouldReturn_INT_FLOAT_STRING() { public void testToDataTypeWith_INT_FLOAT_STRING_shouldReturn_FLOAT_STRING() {
// GIVEN final List<DataType> dataTypes = Arrays.asList(
List<DataType> dataTypes = Arrays.asList(
RecordFieldType.INT.getDataType(), RecordFieldType.INT.getDataType(),
RecordFieldType.FLOAT.getDataType(), RecordFieldType.FLOAT.getDataType(),
RecordFieldType.STRING.getDataType() RecordFieldType.STRING.getDataType()
); );
Set<DataType> expected = new HashSet<>(Arrays.asList( final Set<DataType> expected = new HashSet<>(Arrays.asList(
RecordFieldType.INT.getDataType(),
RecordFieldType.FLOAT.getDataType(), RecordFieldType.FLOAT.getDataType(),
RecordFieldType.STRING.getDataType() RecordFieldType.STRING.getDataType()
)); ));
// WHEN
// THEN
runWithAllPermutations(this::testToDataTypeShouldReturnChoice, dataTypes, expected); runWithAllPermutations(this::testToDataTypeShouldReturnChoice, dataTypes, expected);
} }
@Test @Test
public void testToDataTypeWithMultipleRecord() { public void testToDataTypeWithMultipleRecord() {
// GIVEN final String fieldName = "fieldName";
String fieldName = "fieldName"; final DataType intType = RecordFieldType.INT.getDataType();
DataType fieldType1 = RecordFieldType.INT.getDataType(); final DataType floatType = RecordFieldType.FLOAT.getDataType();
DataType fieldType2 = RecordFieldType.FLOAT.getDataType(); final DataType stringType = RecordFieldType.STRING.getDataType();
DataType fieldType3 = RecordFieldType.STRING.getDataType();
List<DataType> dataTypes = Arrays.asList( final List<DataType> dataTypes = Arrays.asList(
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType1)), RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, intType)),
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType2)), RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, floatType)),
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType3)), RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, stringType)),
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType2)) RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, floatType))
); );
DataType expected = RecordFieldType.RECORD.getRecordDataType(createRecordSchema( final RecordSchema expectedSchema = createRecordSchema(fieldName, RecordFieldType.CHOICE.getChoiceDataType(floatType, stringType));
fieldName, final DataType expecteDataType = RecordFieldType.RECORD.getRecordDataType(expectedSchema);
RecordFieldType.CHOICE.getChoiceDataType(
fieldType1,
fieldType2,
fieldType3
)
));
// WHEN runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, dataTypes, expecteDataType);
// THEN
runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, dataTypes, expected);
} }
@Test @Test
@ -192,8 +193,8 @@ public class TestFieldTypeInference {
} }
private SimpleRecordSchema createRecordSchema(String fieldName, DataType fieldType) { private SimpleRecordSchema createRecordSchema(String fieldName, DataType fieldType) {
return new SimpleRecordSchema(Arrays.asList( return new SimpleRecordSchema(Collections.singletonList(
new RecordField(fieldName, fieldType) new RecordField(fieldName, fieldType)
)); ));
} }
@ -202,28 +203,18 @@ public class TestFieldTypeInference {
} }
private Void testToDataTypeShouldReturnChoice(List<DataType> dataTypes, Set<DataType> expected) { private Void testToDataTypeShouldReturnChoice(List<DataType> dataTypes, Set<DataType> expected) {
// GIVEN
dataTypes.forEach(testSubject::addPossibleDataType); dataTypes.forEach(testSubject::addPossibleDataType);
// WHEN
DataType actual = testSubject.toDataType(); DataType actual = testSubject.toDataType();
// THEN
assertEquals(expected, new HashSet<>(((ChoiceDataType) actual).getPossibleSubTypes())); assertEquals(expected, new HashSet<>(((ChoiceDataType) actual).getPossibleSubTypes()));
return null; return null;
} }
private Void testToDataTypeShouldReturnSingleType(List<DataType> dataTypes, DataType expected) { private Void testToDataTypeShouldReturnSingleType(List<DataType> dataTypes, DataType expected) {
// GIVEN
dataTypes.forEach(testSubject::addPossibleDataType); dataTypes.forEach(testSubject::addPossibleDataType);
// WHEN
DataType actual = testSubject.toDataType(); DataType actual = testSubject.toDataType();
// THEN
assertEquals(expected, actual); assertEquals(expected, actual);
return null; return null;
} }
} }