From 9141f64ef523bc9a150a8b52f453817dad46a637 Mon Sep 17 00:00:00 2001 From: Mark Bathori Date: Tue, 4 Jun 2024 12:45:41 +0200 Subject: [PATCH] NIFI-13356 Fixed ProtobufReader handling of repeated fields This closes #8922 Signed-off-by: David Handermann --- .../converter/ProtobufDataConverter.java | 82 ++++++++++++++++-- .../protobuf/converter/ValueReader.java | 26 ++++++ .../nifi/services/protobuf/ProtoTestUtil.java | 69 ++++++++++++++- .../converter/TestProtobufDataConverter.java | 46 ++++++++-- .../schema/TestProtoSchemaParser.java | 36 +++++++- .../src/test/resources/test_proto3.desc | Bin 1022 -> 984 bytes .../src/test/resources/test_proto3.proto | 3 +- .../test/resources/test_repeated_proto3.desc | Bin 0 -> 755 bytes .../test/resources/test_repeated_proto3.proto | 46 ++++++++++ 9 files changed, 288 insertions(+), 20 deletions(-) create mode 100644 nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java create mode 100644 nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.desc create mode 100644 nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.proto diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java index df5c491fa9..81e0226d4a 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java @@ -30,6 +30,7 @@ import org.apache.nifi.serialization.record.DataType; import org.apache.nifi.serialization.record.MapRecord; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.type.ArrayDataType; import org.apache.nifi.serialization.record.type.RecordDataType; import org.apache.nifi.serialization.record.util.DataTypeUtils; import org.apache.nifi.services.protobuf.FieldType; @@ -38,6 +39,7 @@ import org.apache.nifi.services.protobuf.schema.ProtoSchemaParser; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,6 +49,8 @@ import java.util.function.Function; import static com.google.protobuf.CodedInputStream.decodeZigZag32; import static com.google.protobuf.TextFormat.unsignedToString; +import static org.apache.nifi.services.protobuf.FieldType.STRING; +import static org.apache.nifi.services.protobuf.FieldType.BYTES; /** * The class is responsible for creating Record by mapping the provided proto schema fields with the list of Unknown fields parsed from encoded proto data. @@ -154,7 +158,11 @@ public class ProtobufDataConverter { private Optional convertFieldValues(ProtoField protoField, UnknownFieldSet.Field unknownField) throws InvalidProtocolBufferException { if (!unknownField.getLengthDelimitedList().isEmpty()) { - return Optional.of(convertLengthDelimitedFields(protoField, unknownField.getLengthDelimitedList())); + if (protoField.isRepeatable() && !isLengthDelimitedType(protoField)) { + return Optional.of(convertRepeatedFields(protoField, unknownField.getLengthDelimitedList())); + } else { + return Optional.of(convertLengthDelimitedFields(protoField, unknownField.getLengthDelimitedList())); + } } if (!unknownField.getFixed32List().isEmpty()) { return Optional.of(convertFixed32Fields(protoField, unknownField.getFixed32List())); @@ -169,6 +177,34 @@ public class ProtobufDataConverter { return Optional.empty(); } + private Object convertRepeatedFields(ProtoField protoField, List fieldValues) { + final CodedInputStream inputStream = fieldValues.getFirst().newCodedInput(); + final ProtoType protoType = protoField.getProtoType(); + if (protoType.isScalar()) { + final ValueReader valueReader = switch (FieldType.findValue(protoType.getSimpleName())) { + case BOOL -> CodedInputStream::readBool; + case INT32 -> CodedInputStream::readInt32; + case UINT32 -> value -> Integer.toUnsignedLong(value.readUInt32()); + case SINT32 -> CodedInputStream::readSInt32; + case INT64 -> CodedInputStream::readInt64; + case UINT64 -> value -> new BigInteger(unsignedToString(value.readUInt64())); + case SINT64 -> CodedInputStream::readSInt64; + case FIXED32 -> value -> Integer.toUnsignedLong(value.readFixed32()); + case SFIXED32 -> CodedInputStream::readSFixed32; + case FIXED64 -> value -> new BigInteger(unsignedToString(value.readFixed64())); + case SFIXED64 -> CodedInputStream::readSFixed64; + case FLOAT -> CodedInputStream::readFloat; + case DOUBLE -> CodedInputStream::readDouble; + default -> throw new IllegalStateException(String.format("Unexpected type [%s] was received for field [%s]", + protoType.getSimpleName(), protoField.getFieldName())); + }; + return resolveFieldValue(protoField, processRepeatedValues(inputStream, valueReader), value -> value); + } else { + List values = processRepeatedValues(inputStream, CodedInputStream::readEnum); + return resolveFieldValue(protoField, values, value -> convertEnum(value, protoType)); + } + } + /** * Converts a Length-Delimited field value into it's suitable data type. * @@ -197,6 +233,10 @@ public class ProtobufDataConverter { valueConverter = value -> { try { Optional recordDataType = rootRecordSchema.getDataType(protoField.getFieldName()); + if (protoField.isRepeatable()) { + final ArrayDataType arrayDataType = (ArrayDataType) recordDataType.get(); + recordDataType = Optional.ofNullable(arrayDataType.getElementType()); + } RecordSchema recordSchema = recordDataType.map(dataType -> ((RecordDataType) dataType).getChildSchema()).orElse(generateRecordSchema(messageType.getType().toString())); return createRecord(messageType, value, recordSchema); @@ -220,7 +260,7 @@ public class ProtobufDataConverter { final String typeName = protoField.getProtoType().getSimpleName(); final Function valueConverter = switch (FieldType.findValue(typeName)) { - case FIXED32 -> value -> Long.parseLong(unsignedToString(value)); + case FIXED32 -> Integer::toUnsignedLong; case SFIXED32 -> value -> value; case FLOAT -> Float::intBitsToFloat; default -> @@ -276,11 +316,7 @@ public class ProtobufDataConverter { " [%s] is not Varint field type", protoField.getFieldName(), protoType.getSimpleName())); }; } else { - valueConverter = value -> { - final EnumType enumType = (EnumType) schema.getType(protoType); - Objects.requireNonNull(enumType, String.format("Enum with name [%s] not found in the provided proto files", protoType)); - return enumType.constant(Integer.parseInt(value.toString())).getName(); - }; + valueConverter = value -> convertEnum(value.intValue(), protoType); } return resolveFieldValue(protoField, values, valueConverter); @@ -297,7 +333,7 @@ public class ProtobufDataConverter { } if (!protoField.isRepeatable()) { - return resultValues.get(0); + return resultValues.getFirst(); } else { return resultValues.toArray(); } @@ -327,6 +363,12 @@ public class ProtobufDataConverter { return mapResult; } + private String convertEnum(Integer value, ProtoType protoType) { + final EnumType enumType = (EnumType) schema.getType(protoType); + Objects.requireNonNull(enumType, String.format("Enum with name [%s] not found in the provided proto files", protoType)); + return enumType.constant(value).getName(); + } + /** * Process a 'google.protobuf.Any' typed field. The method gets the schema for the message type provided in the 'type_url' property * and parse the serialized message from the 'value' field. The result record will contain only the parsed message's fields. @@ -368,4 +410,28 @@ public class ProtobufDataConverter { private String getQualifiedTypeName(String typeName) { return typeName.substring(typeName.lastIndexOf('/') + 1); } + + private List processRepeatedValues(CodedInputStream input, ValueReader valueReader) { + List result = new ArrayList<>(); + try { + while (input.getBytesUntilLimit() > 0) { + result.add(valueReader.apply(input)); + } + } catch (Exception e) { + throw new IllegalStateException("Unable to parse repeated field", e); + } + return result; + } + + private boolean isLengthDelimitedType(ProtoField protoField) { + boolean lengthDelimitedScalarType = false; + final ProtoType protoType = protoField.getProtoType(); + + if (protoType.isScalar()) { + final FieldType fieldType = FieldType.findValue(protoType.getSimpleName()); + lengthDelimitedScalarType = fieldType.equals(STRING) || fieldType.equals(BYTES); + } + + return lengthDelimitedScalarType || schema.getType(protoType) instanceof MessageType; + } } diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java new file mode 100644 index 0000000000..cff78dea51 --- /dev/null +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.services.protobuf.converter; + +import java.io.IOException; + +@FunctionalInterface +interface ValueReader { + + R apply(T t) throws IOException; + +} diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java index 4a10c0ecfd..c0d273da95 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java @@ -44,6 +44,12 @@ public class ProtoTestUtil { return schemaLoader.loadSchema(); } + public static Schema loadRepeatedProto3TestSchema() { + final SchemaLoader schemaLoader = new SchemaLoader(FileSystems.getDefault()); + schemaLoader.initRoots(Collections.singletonList(Location.get(BASE_TEST_PATH + "test_repeated_proto3.proto")), Collections.emptyList()); + return schemaLoader.loadSchema(); + } + public static Schema loadProto2TestSchema() { final SchemaLoader schemaLoader = new SchemaLoader(FileSystems.getDefault()); schemaLoader.initRoots(Arrays.asList( @@ -76,13 +82,10 @@ public class ProtoTestUtil { DynamicMessage nestedMessage = DynamicMessage .newBuilder(nestedMessageDescriptor) .setField(nestedMessageDescriptor.findFieldByNumber(20), enumValueDescriptor.findValueByNumber(2)) - .addRepeatedField(nestedMessageDescriptor.findFieldByNumber(21), "Repeated 1") - .addRepeatedField(nestedMessageDescriptor.findFieldByNumber(21), "Repeated 2") - .addRepeatedField(nestedMessageDescriptor.findFieldByNumber(21), "Repeated 3") + .setField(nestedMessageDescriptor.findFieldByNumber(21), Arrays.asList(mapEntry1, mapEntry2)) .setField(nestedMessageDescriptor.findFieldByNumber(22), "One Of Option") .setField(nestedMessageDescriptor.findFieldByNumber(23), true) .setField(nestedMessageDescriptor.findFieldByNumber(24), 3) - .setField(nestedMessageDescriptor.findFieldByNumber(25), Arrays.asList(mapEntry1, mapEntry2)) .build(); DynamicMessage message = DynamicMessage @@ -108,6 +111,64 @@ public class ProtoTestUtil { return message.toByteString().newInput(); } + public static InputStream generateInputDataForRepeatedProto3() throws IOException, Descriptors.DescriptorValidationException { + DescriptorProtos.FileDescriptorSet descriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(new FileInputStream(BASE_TEST_PATH + "test_repeated_proto3.desc")); + Descriptors.FileDescriptor fileDescriptor = Descriptors.FileDescriptor.buildFrom(descriptorSet.getFile(0), new Descriptors.FileDescriptor[0]); + + Descriptors.Descriptor messageDescriptor = fileDescriptor.findMessageTypeByName("RootMessage"); + Descriptors.Descriptor repeatedMessageDescriptor = fileDescriptor.findMessageTypeByName("RepeatedMessage"); + Descriptors.EnumDescriptor enumValueDescriptor = fileDescriptor.findEnumTypeByName("TestEnum"); + + DynamicMessage repeatedMessage1 = DynamicMessage + .newBuilder(repeatedMessageDescriptor) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(1), true) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(1), false) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(2), "Test text1") + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(2), "Test text2") + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(3), Integer.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(3), Integer.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(4), -1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(4), -2) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(5), Integer.MIN_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(5), Integer.MIN_VALUE + 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(6), -2) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(6), -3) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(7), Integer.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(7), Integer.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(8), Double.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(8), Double.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(9), Float.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(9), Float.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(10), "Test bytes1".getBytes()) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(10), "Test bytes2".getBytes()) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(11), Long.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(11), Long.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(12), -1L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(12), -2L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(13), Long.MIN_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(13), Long.MIN_VALUE + 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(14), -2L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(14), -1L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(15), Long.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(15), Long.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(16), enumValueDescriptor.findValueByNumber(1)) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(16), enumValueDescriptor.findValueByNumber(2)) + .build(); + + DynamicMessage repeatedMessage2 = DynamicMessage + .newBuilder(repeatedMessageDescriptor) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(1), true) + .build(); + + DynamicMessage rootMessage = DynamicMessage + .newBuilder(messageDescriptor) + .addRepeatedField(messageDescriptor.findFieldByNumber(1), repeatedMessage1) + .addRepeatedField(messageDescriptor.findFieldByNumber(1), repeatedMessage2) + .build(); + + return rootMessage.toByteString().newInput(); + } + public static InputStream generateInputDataForProto2() throws IOException, Descriptors.DescriptorValidationException { DescriptorProtos.FileDescriptorSet anyDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(new FileInputStream(BASE_TEST_PATH + "google/protobuf/any.desc")); Descriptors.FileDescriptor anyDesc = Descriptors.FileDescriptor.buildFrom(anyDescriptorSet.getFile(0), new Descriptors.FileDescriptor[]{}); diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java index 9b4bdabe78..7aeafa895a 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java @@ -20,7 +20,6 @@ import com.google.protobuf.Descriptors; import com.squareup.wire.schema.Schema; import org.apache.nifi.serialization.record.MapRecord; import org.apache.nifi.serialization.record.RecordSchema; -import org.apache.nifi.serialization.record.util.DataTypeUtils; import org.apache.nifi.services.protobuf.ProtoTestUtil; import org.apache.nifi.services.protobuf.schema.ProtoSchemaParser; import org.junit.jupiter.api.Test; @@ -31,6 +30,7 @@ import java.util.Map; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto2TestSchema; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto3TestSchema; +import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadRepeatedProto3TestSchema; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; @@ -58,22 +58,58 @@ public class TestProtobufDataConverter { assertEquals(Float.MAX_VALUE, record.getValue("floatField")); assertArrayEquals("Test bytes".getBytes(), (byte[]) record.getValue("bytesField")); assertEquals(Long.MAX_VALUE, record.getValue("int64Field")); - assertEquals(new BigInteger("18446744073709551615"), DataTypeUtils.toBigInt(record.getValue("uint64Field"), "field12")); + assertEquals(new BigInteger("18446744073709551615"), record.getValue("uint64Field")); assertEquals(Long.MIN_VALUE, record.getValue("sint64Field")); - assertEquals(new BigInteger("18446744073709551614"), DataTypeUtils.toBigInt(record.getValue("fixed64Field"), "field14")); + assertEquals(new BigInteger("18446744073709551614"), record.getValue("fixed64Field")); assertEquals(Long.MAX_VALUE, record.getValue("sfixed64Field")); final MapRecord nestedRecord = (MapRecord) record.getValue("nestedMessage"); assertEquals("ENUM_VALUE_3", nestedRecord.getValue("testEnum")); - assertArrayEquals(new Object[]{"Repeated 1", "Repeated 2", "Repeated 3"}, (Object[]) nestedRecord.getValue("repeatedField")); + assertEquals(Map.of("test_key_entry1", 101, "test_key_entry2", 202), nestedRecord.getValue("testMap")); // assert only one field is set in the OneOf field assertNull(nestedRecord.getValue("stringOption")); assertNull(nestedRecord.getValue("booleanOption")); assertEquals(3, nestedRecord.getValue("int32Option")); + } - assertEquals(Map.of("test_key_entry1", 101, "test_key_entry2", 202), nestedRecord.getValue("testMap")); + @Test + public void testDataConverterForRepeatedProto3() throws Descriptors.DescriptorValidationException, IOException { + final Schema schema = loadRepeatedProto3TestSchema(); + final RecordSchema recordSchema = new ProtoSchemaParser(schema).createSchema("RootMessage"); + + final ProtobufDataConverter dataConverter = new ProtobufDataConverter(schema, "RootMessage", recordSchema, false, false); + final MapRecord record = dataConverter.createRecord(ProtoTestUtil.generateInputDataForRepeatedProto3()); + + final Object[] repeatedMessage = (Object[]) record.getValue("repeatedMessage"); + final MapRecord record1 = (MapRecord) repeatedMessage[0]; + + assertArrayEquals(new Object[]{true, false}, (Object[]) record1.getValue("booleanField")); + assertArrayEquals(new Object[]{"Test text1", "Test text2"}, (Object[]) record1.getValue("stringField")); + assertArrayEquals(new Object[]{Integer.MAX_VALUE, Integer.MAX_VALUE - 1}, (Object[]) record1.getValue("int32Field")); + assertArrayEquals(new Object[]{4294967295L, 4294967294L}, (Object[]) record1.getValue("uint32Field")); + assertArrayEquals(new Object[]{Integer.MIN_VALUE, Integer.MIN_VALUE + 1}, (Object[]) record1.getValue("sint32Field")); + assertArrayEquals(new Object[]{4294967294L, 4294967293L}, (Object[]) record1.getValue("fixed32Field")); + assertArrayEquals(new Object[]{Integer.MAX_VALUE, Integer.MAX_VALUE - 1}, (Object[]) record1.getValue("sfixed32Field")); + assertArrayEquals(new Object[]{Double.MAX_VALUE, Double.MAX_VALUE - 1}, (Object[]) record1.getValue("doubleField")); + assertArrayEquals(new Object[]{Float.MAX_VALUE, Float.MAX_VALUE - 1}, (Object[]) record1.getValue("floatField")); + assertArrayEquals(new Object[]{Long.MAX_VALUE, Long.MAX_VALUE - 1}, (Object[]) record1.getValue("int64Field")); + assertArrayEquals(new Object[]{Long.MIN_VALUE, Long.MIN_VALUE + 1}, (Object[]) record1.getValue("sint64Field")); + assertArrayEquals(new Object[]{Long.MAX_VALUE, Long.MAX_VALUE - 1}, (Object[]) record1.getValue("sfixed64Field")); + assertArrayEquals(new Object[]{"ENUM_VALUE_2", "ENUM_VALUE_3"}, (Object[]) record1.getValue("testEnum")); + + final Object[] uint64FieldValues = (Object[]) record1.getValue("uint64Field"); + assertEquals(new BigInteger("18446744073709551615"), uint64FieldValues[0]); + assertEquals(new BigInteger("18446744073709551614"), uint64FieldValues[1]); + + final Object[] bytesFieldValues = (Object[]) record1.getValue("bytesField"); + assertArrayEquals("Test bytes1".getBytes(), (byte[]) bytesFieldValues[0]); + assertArrayEquals("Test bytes2".getBytes(), (byte[]) bytesFieldValues[1]); + + final MapRecord record2 = (MapRecord) repeatedMessage[1]; + + assertArrayEquals(new Object[]{true}, (Object[]) record2.getValue("booleanField")); } @Test diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java index d313bb595c..42bb858eab 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java @@ -20,12 +20,15 @@ import org.apache.nifi.serialization.SimpleRecordSchema; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.type.ArrayDataType; +import org.apache.nifi.serialization.record.type.RecordDataType; import org.junit.jupiter.api.Test; import java.util.Arrays; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto2TestSchema; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto3TestSchema; +import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadRepeatedProto3TestSchema; import static org.junit.jupiter.api.Assertions.assertEquals; public class TestProtoSchemaParser { @@ -52,7 +55,6 @@ public class TestProtoSchemaParser { new RecordField("sfixed64Field", RecordFieldType.LONG.getDataType()), new RecordField("nestedMessage", RecordFieldType.RECORD.getRecordDataType(new SimpleRecordSchema(Arrays.asList( new RecordField("testEnum", RecordFieldType.ENUM.getEnumDataType(Arrays.asList("ENUM_VALUE_1", "ENUM_VALUE_2", "ENUM_VALUE_3"))), - new RecordField("repeatedField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType())), new RecordField("testMap", RecordFieldType.MAP.getMapDataType(RecordFieldType.INT.getDataType())), new RecordField("stringOption", RecordFieldType.STRING.getDataType()), new RecordField("booleanOption", RecordFieldType.BOOLEAN.getDataType()), @@ -64,6 +66,38 @@ public class TestProtoSchemaParser { assertEquals(expected, actual); } + @Test + public void testSchemaParserForRepeatedProto3() { + final ProtoSchemaParser schemaParser = new ProtoSchemaParser(loadRepeatedProto3TestSchema()); + + final SimpleRecordSchema expected = + new SimpleRecordSchema(Arrays.asList( + new RecordField("booleanField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BOOLEAN.getDataType())), + new RecordField("stringField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType())), + new RecordField("int32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType())), + new RecordField("uint32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("sint32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("fixed32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("sfixed32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType())), + new RecordField("doubleField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.DOUBLE.getDataType())), + new RecordField("floatField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.FLOAT.getDataType())), + new RecordField("bytesField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType()))), + new RecordField("int64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("uint64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BIGINT.getDataType())), + new RecordField("sint64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("fixed64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BIGINT.getDataType())), + new RecordField("sfixed64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("testEnum", RecordFieldType.ARRAY.getArrayDataType( + RecordFieldType.ENUM.getEnumDataType(Arrays.asList("ENUM_VALUE_1", "ENUM_VALUE_2", "ENUM_VALUE_3")))) + )); + + final RecordSchema actual = schemaParser.createSchema("RootMessage"); + final ArrayDataType arrayDataType = (ArrayDataType) actual.getField("repeatedMessage").get().getDataType(); + final RecordDataType recordDataType = (RecordDataType) arrayDataType.getElementType(); + + assertEquals(expected, recordDataType.getChildSchema()); + } + @Test public void testSchemaParserForProto2() { final ProtoSchemaParser schemaParser = new ProtoSchemaParser(loadProto2TestSchema()); diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.desc b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.desc index a2316f3f87c9c701a5b9f2acf469341b693338e5..1dbfb60613870865c1b38d2d89bafe68f7a00090 100644 GIT binary patch delta 42 ycmeyzeuJHf>*_|PL?*_SlT(;1cucw2OHzwVd=m>KL?=5iYfV1LR5Uq@c?tj_i4AZ7 delta 81 zcmcb?{*RrB>-R>cL?*`blT(;1R8_cmi&6_x6H8K4+%i*hQY1tbm^C testMap = 21; oneof oneOfField { string stringOption = 22; bool booleanOption = 23; int32 int32Option = 24; } - map testMap = 25; } enum TestEnum { diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.desc b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.desc new file mode 100644 index 0000000000000000000000000000000000000000..70811cb38850842be56ca0494621f69464eba24e GIT binary patch literal 755 zcmZ9KPfNov7>C{e?rpcmoQ4M(LqR+Y?8Jj7Jxmashv2-HmS&?+o3UmvzndS*kKo%T znQ2dj=l3*k@=M^i2N#kTGTieWk0ejR-Cjtsm{*_KE4WMux#C3;Z8?1e-*_5LTGP&r z8<08E|7*^g;_RmjenX-&@M0k{9_247&ys2}ht(LJf?LRans1fT#(6_#BYFvt=5i4z zp@Tw?S)^RyU$YK&gNsHjN;lzdi)<*Y;; z;*6{9I_R{RjmqB9&f3SHHQ zWt4|E#t03n`Y?v+!zg1^`nnI}ocb`5j}kF`80#9G@&PXM!`FnMFb;