From ec09c56e933c8418c97f89bfbcf898c8338b5d06 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Wed, 14 Sep 2022 16:03:29 -0400 Subject: [PATCH] NIFI-10508: When inferring data types for values, allow float and double to encapsulate byte/short/int/long values Signed-off-by: Matthew Burgess This closes #6421 --- .../record/util/DataTypeSet.java | 78 +++++++++++++ .../record/util/DataTypeUtils.java | 107 +++++++++--------- .../record/TestDataTypeUtils.java | 6 + .../record/util/TestDataTypeSet.java | 68 +++++++++++ .../TestStandardSchemaValidator.java | 11 -- .../nifi/queryrecord/FlowFileEnumerator.java | 16 +-- .../schema/inference/FieldTypeInference.java | 12 +- .../inference/TestFieldTypeInference.java | 91 +++++++-------- 8 files changed, 260 insertions(+), 129 deletions(-) create mode 100644 nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java create mode 100644 nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java diff --git a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java new file mode 100644 index 0000000000..33ee847571 --- /dev/null +++ b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.serialization.record.util; + +import org.apache.nifi.serialization.record.DataType; +import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.serialization.record.type.ChoiceDataType; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * A container class, to which multiple DataTypes can be added, such that + * adding any two types where one is more narrow than the other will result + * in combining the two types into the wider type. + */ +public class DataTypeSet { + private final List types = new ArrayList<>(); + + /** + * Adds the given data type to the set of types to consider + * @param dataType the data type to add + */ + public void add(final DataType dataType) { + if (dataType == null) { + return; + } + + if (dataType.getFieldType() == RecordFieldType.CHOICE) { + final ChoiceDataType choiceDataType = (ChoiceDataType) dataType; + choiceDataType.getPossibleSubTypes().forEach(this::add); + return; + } + + if (types.contains(dataType)) { + return; + } + + DataType toRemove = null; + DataType toAdd = null; + for (final DataType currentType : types) { + final Optional widerType = DataTypeUtils.getWiderType(currentType, dataType); + if (widerType.isPresent()) { + toRemove = currentType; + toAdd = widerType.get(); + } + } + + if (toRemove != null) { + types.remove(toRemove); + } + + types.add( toAdd == null ? dataType : toAdd ); + } + + /** + * @return the combined types + */ + public List getTypes() { + return new ArrayList<>(types); + } +} diff --git a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java index e281d6cfbe..21e3b71827 100644 --- a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java +++ b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java @@ -64,7 +64,6 @@ import java.util.EnumMap; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -590,34 +589,8 @@ public class DataTypeUtils { m.forEach((k, v) -> map.put(k == null ? null : k.toString(), v)); } return inferRecordDataType(map); -// // Check if all types are the same. -// if (map.isEmpty()) { -// return RecordFieldType.MAP.getMapDataType(RecordFieldType.STRING.getDataType()); -// } -// -// Object valueFromMap = null; -// Class valueClass = null; -// for (final Object val : map.values()) { -// if (val == null) { -// continue; -// } -// -// valueFromMap = val; -// final Class currentValClass = val.getClass(); -// if (valueClass == null) { -// valueClass = currentValClass; -// } else { -// // If we have two elements that are of different types, then we cannot have a Map. Must be a Record. -// if (valueClass != currentValClass) { -// return inferRecordDataType(map); -// } -// } -// } -// -// // All values appear to be of the same type, so assume that it's a map. -// final DataType elementDataType = inferDataType(valueFromMap, RecordFieldType.STRING.getDataType()); -// return RecordFieldType.MAP.getMapDataType(elementDataType); } + if (value.getClass().isArray()) { DataType mergedDataType = null; @@ -633,8 +606,9 @@ public class DataTypeUtils { return RecordFieldType.ARRAY.getArrayDataType(mergedDataType); } + if (value instanceof Iterable) { - final Iterable iterable = (Iterable) value; + final Iterable iterable = (Iterable) value; DataType mergedDataType = null; for (final Object arrayValue : iterable) { @@ -1998,33 +1972,34 @@ public class DataTypeUtils { return widerType.get(); } - final Set possibleTypes = new LinkedHashSet<>(); - if (thisDataType.getFieldType() == RecordFieldType.CHOICE) { - possibleTypes.addAll(((ChoiceDataType) thisDataType).getPossibleSubTypes()); - } else { - possibleTypes.add(thisDataType); - } + final DataTypeSet dataTypeSet = new DataTypeSet(); + dataTypeSet.add(thisDataType); + dataTypeSet.add(otherDataType); - if (otherDataType.getFieldType() == RecordFieldType.CHOICE) { - possibleTypes.addAll(((ChoiceDataType) otherDataType).getPossibleSubTypes()); - } else { - possibleTypes.add(otherDataType); - } - - ArrayList possibleChildTypes = new ArrayList<>(possibleTypes); - Collections.sort(possibleChildTypes, Comparator.comparing(DataType::getFieldType)); + final List possibleChildTypes = dataTypeSet.getTypes(); + possibleChildTypes.sort(Comparator.comparing(DataType::getFieldType)); return RecordFieldType.CHOICE.getChoiceDataType(possibleChildTypes); } } public static Optional getWiderType(final DataType thisDataType, final DataType otherDataType) { + if (thisDataType == null) { + return Optional.ofNullable(otherDataType); + } + if (otherDataType == null) { + return Optional.of(thisDataType); + } + final RecordFieldType thisFieldType = thisDataType.getFieldType(); final RecordFieldType otherFieldType = otherDataType.getFieldType(); final int thisIntTypeValue = getIntegerTypeValue(thisFieldType); final int otherIntTypeValue = getIntegerTypeValue(otherFieldType); - if (thisIntTypeValue > -1 && otherIntTypeValue > -1) { + final boolean thisIsInt = thisIntTypeValue > -1; + final boolean otherIsInt = otherIntTypeValue > -1; + + if (thisIsInt && otherIsInt) { if (thisIntTypeValue > otherIntTypeValue) { return Optional.of(thisDataType); } @@ -2032,25 +2007,37 @@ public class DataTypeUtils { return Optional.of(otherDataType); } + final boolean otherIsDecimal = isDecimalType(otherFieldType); + switch (thisFieldType) { + case BYTE: + case SHORT: + case INT: + case LONG: + if (otherIsDecimal) { + return Optional.of(otherDataType); + } + break; case FLOAT: - if (otherFieldType == RecordFieldType.DOUBLE) { - return Optional.of(otherDataType); - } else if (otherFieldType == RecordFieldType.DECIMAL) { + if (otherFieldType == RecordFieldType.DOUBLE || otherFieldType == RecordFieldType.DECIMAL) { return Optional.of(otherDataType); } + if (otherFieldType == RecordFieldType.BYTE || otherFieldType == RecordFieldType.SHORT || otherFieldType == RecordFieldType.INT || otherFieldType == RecordFieldType.LONG) { + return Optional.of(thisDataType); + } break; case DOUBLE: - if (otherFieldType == RecordFieldType.FLOAT) { - return Optional.of(thisDataType); - } else if (otherFieldType == RecordFieldType.DECIMAL) { + if (otherFieldType == RecordFieldType.DECIMAL) { return Optional.of(otherDataType); } + if (otherFieldType == RecordFieldType.BYTE || otherFieldType == RecordFieldType.SHORT || otherFieldType == RecordFieldType.INT || otherFieldType == RecordFieldType.LONG + || otherFieldType == RecordFieldType.FLOAT) { + + return Optional.of(thisDataType); + } break; case DECIMAL: - if (otherFieldType == RecordFieldType.DOUBLE) { - return Optional.of(thisDataType); - } else if (otherFieldType == RecordFieldType.FLOAT) { + if (otherFieldType == RecordFieldType.DOUBLE || otherFieldType == RecordFieldType.FLOAT || otherIsInt) { return Optional.of(thisDataType); } else if (otherFieldType == RecordFieldType.DECIMAL) { final DecimalDataType thisDecimalDataType = (DecimalDataType) thisDataType; @@ -2062,12 +2049,13 @@ public class DataTypeUtils { } break; case CHAR: + case UUID: if (otherFieldType == RecordFieldType.STRING) { return Optional.of(otherDataType); } break; case STRING: - if (otherFieldType == RecordFieldType.CHAR) { + if (otherFieldType == RecordFieldType.CHAR || otherFieldType == RecordFieldType.UUID) { return Optional.of(thisDataType); } break; @@ -2076,6 +2064,17 @@ public class DataTypeUtils { return Optional.empty(); } + private static boolean isDecimalType(final RecordFieldType fieldType) { + switch (fieldType) { + case FLOAT: + case DOUBLE: + case DECIMAL: + return true; + default: + return false; + } + } + private static int getIntegerTypeValue(final RecordFieldType fieldType) { switch (fieldType) { case BIGINT: diff --git a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java index 46cd012239..863949b486 100644 --- a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java +++ b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java @@ -94,6 +94,12 @@ public class TestDataTypeUtils { assertEquals(ts.getTime(), sDate.getTime(), "Times didn't match"); } + @Test + public void testIntDoubleWiderType() { + assertEquals(Optional.of(RecordFieldType.DOUBLE.getDataType()), DataTypeUtils.getWiderType(RecordFieldType.INT.getDataType(), RecordFieldType.DOUBLE.getDataType())); + assertEquals(Optional.of(RecordFieldType.DOUBLE.getDataType()), DataTypeUtils.getWiderType(RecordFieldType.DOUBLE.getDataType(), RecordFieldType.INT.getDataType())); + } + /* * This was a bug in NiFi 1.8 where converting from a Timestamp to a Date with the record path API * would throw an exception. diff --git a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java new file mode 100644 index 0000000000..14800688da --- /dev/null +++ b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.serialization.record.util; + +import org.apache.nifi.serialization.record.RecordFieldType; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestDataTypeSet { + + @Test + public void testCombineNarrowThenWider() { + final DataTypeSet set = new DataTypeSet(); + set.add(RecordFieldType.INT.getDataType()); + set.add(RecordFieldType.DOUBLE.getDataType()); + assertEquals(Collections.singletonList(RecordFieldType.DOUBLE.getDataType()), set.getTypes()); + } + + @Test + public void testAddIncompatible() { + final DataTypeSet set = new DataTypeSet(); + set.add(RecordFieldType.INT.getDataType()); + set.add(RecordFieldType.BOOLEAN.getDataType()); + assertEquals(Arrays.asList(RecordFieldType.INT.getDataType(), RecordFieldType.BOOLEAN.getDataType()), set.getTypes()); + + } + + @Test + public void addSingleType() { + final DataTypeSet set = new DataTypeSet(); + set.add(RecordFieldType.INT.getDataType()); + assertEquals(Collections.singletonList(RecordFieldType.INT.getDataType()), set.getTypes()); + + } + + @Test + public void testCombineWiderThenNarrow() { + final DataTypeSet set = new DataTypeSet(); + set.add(RecordFieldType.DOUBLE.getDataType()); + set.add(RecordFieldType.INT.getDataType()); + assertEquals(Collections.singletonList(RecordFieldType.DOUBLE.getDataType()), set.getTypes()); + } + + @Test + public void testAddNothing() { + final DataTypeSet set = new DataTypeSet(); + assertEquals(Collections.emptyList(), set.getTypes()); + } +} diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java index 17a109bba1..baa1063500 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java @@ -193,11 +193,6 @@ public class TestStandardSchemaValidator { whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Integer.MAX_VALUE, RecordFieldType.DECIMAL); } - @Test - public void testIntegerOutsideRangeIsConsideredAsInvalid() { - whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_FLOAT.intValue() + 1, RecordFieldType.FLOAT); - // Double handles integer completely - } @Test public void testLongWithinRangeIsConsideredToBeValidFloatingPoint() { @@ -206,12 +201,6 @@ public class TestStandardSchemaValidator { whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Long.MAX_VALUE, RecordFieldType.DECIMAL); } - @Test - public void testLongOutsideRangeIsConsideredAsInvalid() { - whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_FLOAT + 1, RecordFieldType.FLOAT); - whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_DOUBLE + 1, RecordFieldType.DOUBLE); - } - @Test public void testBigintWithinRangeIsConsideredToBeValidFloatingPoint() { whenValueIsAcceptedAsDataTypeThenConsideredAsValid(BigInteger.valueOf(5L), RecordFieldType.FLOAT); diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java index db66c5a80e..1a0d2ba2ec 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java @@ -120,17 +120,17 @@ public class FlowFileEnumerator implements Enumerator { return filtered; } - private Object cast(Object o) { - if (o == null) { + private Object cast(final Object toCast) { + if (toCast == null) { return null; - } else if (o.getClass().isArray()) { - List l = new ArrayList(Array.getLength(o)); - for (int i = 0; i < Array.getLength(o); i++) { - l.add(Array.get(o, i)); + } else if (toCast.getClass().isArray()) { + final List list = new ArrayList<>(Array.getLength(toCast)); + for (int i = 0; i < Array.getLength(toCast); i++) { + list.add(Array.get(toCast, i)); } - return l; + return list; } else { - return o; + return toCast; } } diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java index 1f52cb8357..a4186eef60 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java @@ -34,7 +34,7 @@ public class FieldTypeInference { // unique value for the data type, and so this paradigm allows us to avoid the cost of creating // and using the HashSet. private DataType singleDataType = null; - private Set possibleDataTypes = new HashSet<>(); + private final Set possibleDataTypes = new HashSet<>(); public void addPossibleDataType(final DataType dataType) { if (dataType == null) { @@ -73,17 +73,17 @@ public class FieldTypeInference { possibleDataTypes.add(singleDataType); } - for (DataType possibleDataType : possibleDataTypes) { - RecordFieldType possibleFieldType = possibleDataType.getFieldType(); + for (final DataType possibleDataType : possibleDataTypes) { + final RecordFieldType possibleFieldType = possibleDataType.getFieldType(); if (!possibleFieldType.equals(RecordFieldType.STRING) && possibleFieldType.isWiderThan(additionalFieldType)) { return; } } - Iterator possibleDataTypeIterator = possibleDataTypes.iterator(); + final Iterator possibleDataTypeIterator = possibleDataTypes.iterator(); while (possibleDataTypeIterator.hasNext()) { - DataType possibleDataType = possibleDataTypeIterator.next(); - RecordFieldType possibleFieldType = possibleDataType.getFieldType(); + final DataType possibleDataType = possibleDataTypeIterator.next(); + final RecordFieldType possibleFieldType = possibleDataType.getFieldType(); if (!additionalFieldType.equals(RecordFieldType.STRING) && additionalFieldType.isWiderThan(possibleFieldType)) { possibleDataTypeIterator.remove(); diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java index ea2f470095..80b60c61d8 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java @@ -20,11 +20,13 @@ import org.apache.nifi.serialization.SimpleRecordSchema; import org.apache.nifi.serialization.record.DataType; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.serialization.record.RecordSchema; import org.apache.nifi.serialization.record.type.ChoiceDataType; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -40,6 +42,25 @@ public class TestFieldTypeInference { testSubject = new FieldTypeInference(); } + @Test + public void testIntegerCombinedWithDouble() { + final FieldTypeInference inference = new FieldTypeInference(); + inference.addPossibleDataType(RecordFieldType.INT.getDataType()); + inference.addPossibleDataType(RecordFieldType.DOUBLE.getDataType()); + + assertEquals(RecordFieldType.DOUBLE.getDataType(), inference.toDataType()); + } + + @Test + public void testIntegerCombinedWithFloat() { + final FieldTypeInference inference = new FieldTypeInference(); + inference.addPossibleDataType(RecordFieldType.INT.getDataType()); + inference.addPossibleDataType(RecordFieldType.FLOAT.getDataType()); + + assertEquals(RecordFieldType.FLOAT.getDataType(), inference.toDataType()); + } + + @Test public void testToDataTypeWith_SHORT_INT_LONG_shouldReturn_LONG() { // GIVEN @@ -58,20 +79,13 @@ public class TestFieldTypeInference { @Test public void testToDataTypeWith_INT_FLOAT_ShouldReturn_INT_FLOAT() { - // GIVEN - List dataTypes = Arrays.asList( + final List dataTypes = Arrays.asList( RecordFieldType.INT.getDataType(), RecordFieldType.FLOAT.getDataType() ); - Set expected = new HashSet<>(Arrays.asList( - RecordFieldType.INT.getDataType(), - RecordFieldType.FLOAT.getDataType() - )); - - // WHEN - // THEN - runWithAllPermutations(this::testToDataTypeShouldReturnChoice, dataTypes, expected); + final DataType expected = RecordFieldType.FLOAT.getDataType(); + runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, dataTypes, expected); } @Test @@ -94,52 +108,39 @@ public class TestFieldTypeInference { } @Test - public void testToDataTypeWith_INT_FLOAT_STRING_shouldReturn_INT_FLOAT_STRING() { - // GIVEN - List dataTypes = Arrays.asList( + public void testToDataTypeWith_INT_FLOAT_STRING_shouldReturn_FLOAT_STRING() { + final List dataTypes = Arrays.asList( RecordFieldType.INT.getDataType(), RecordFieldType.FLOAT.getDataType(), RecordFieldType.STRING.getDataType() ); - Set expected = new HashSet<>(Arrays.asList( - RecordFieldType.INT.getDataType(), + final Set expected = new HashSet<>(Arrays.asList( RecordFieldType.FLOAT.getDataType(), RecordFieldType.STRING.getDataType() )); - // WHEN - // THEN runWithAllPermutations(this::testToDataTypeShouldReturnChoice, dataTypes, expected); } @Test public void testToDataTypeWithMultipleRecord() { - // GIVEN - String fieldName = "fieldName"; - DataType fieldType1 = RecordFieldType.INT.getDataType(); - DataType fieldType2 = RecordFieldType.FLOAT.getDataType(); - DataType fieldType3 = RecordFieldType.STRING.getDataType(); + final String fieldName = "fieldName"; + final DataType intType = RecordFieldType.INT.getDataType(); + final DataType floatType = RecordFieldType.FLOAT.getDataType(); + final DataType stringType = RecordFieldType.STRING.getDataType(); - List dataTypes = Arrays.asList( - RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType1)), - RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType2)), - RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType3)), - RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, fieldType2)) + final List dataTypes = Arrays.asList( + RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, intType)), + RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, floatType)), + RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, stringType)), + RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, floatType)) ); - DataType expected = RecordFieldType.RECORD.getRecordDataType(createRecordSchema( - fieldName, - RecordFieldType.CHOICE.getChoiceDataType( - fieldType1, - fieldType2, - fieldType3 - ) - )); + final RecordSchema expectedSchema = createRecordSchema(fieldName, RecordFieldType.CHOICE.getChoiceDataType(floatType, stringType)); + final DataType expecteDataType = RecordFieldType.RECORD.getRecordDataType(expectedSchema); - // WHEN - // THEN - runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, dataTypes, expected); + runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, dataTypes, expecteDataType); } @Test @@ -192,8 +193,8 @@ public class TestFieldTypeInference { } private SimpleRecordSchema createRecordSchema(String fieldName, DataType fieldType) { - return new SimpleRecordSchema(Arrays.asList( - new RecordField(fieldName, fieldType) + return new SimpleRecordSchema(Collections.singletonList( + new RecordField(fieldName, fieldType) )); } @@ -202,28 +203,18 @@ public class TestFieldTypeInference { } private Void testToDataTypeShouldReturnChoice(List dataTypes, Set expected) { - // GIVEN dataTypes.forEach(testSubject::addPossibleDataType); - // WHEN DataType actual = testSubject.toDataType(); - - // THEN assertEquals(expected, new HashSet<>(((ChoiceDataType) actual).getPossibleSubTypes())); - return null; } private Void testToDataTypeShouldReturnSingleType(List dataTypes, DataType expected) { - // GIVEN dataTypes.forEach(testSubject::addPossibleDataType); - // WHEN DataType actual = testSubject.toDataType(); - - // THEN assertEquals(expected, actual); - return null; } }