From cf49a58ee75601e1d0d7512104b9ed0ca2e8ec41 Mon Sep 17 00:00:00 2001 From: Wesley-Lawrence Date: Sun, 30 Jul 2017 14:49:00 -0400 Subject: [PATCH] NIFI-4215 Allow Complex Avro Schema Parsing NiFi can now parse an Avro schema of a record that references an already defined record, including itself. Signed-off-by: James Wing This closes #2034. --- .../serialization/SimpleRecordSchema.java | 59 +++++--- .../org/apache/nifi/avro/AvroTypeUtil.java | 63 +++++--- .../apache/nifi/avro/TestAvroTypeUtil.java | 139 ++++++++++++++++++ 3 files changed, 221 insertions(+), 40 deletions(-) diff --git a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/SimpleRecordSchema.java b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/SimpleRecordSchema.java index 017aef1e38..871c7bf74f 100644 --- a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/SimpleRecordSchema.java +++ b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/SimpleRecordSchema.java @@ -32,8 +32,8 @@ import org.apache.nifi.serialization.record.RecordSchema; import org.apache.nifi.serialization.record.SchemaIdentifier; public class SimpleRecordSchema implements RecordSchema { - private final List fields; - private final Map fieldIndices; + private List fields = null; + private Map fieldIndices = null; private final boolean textAvailable; private final String text; private final String schemaFormat; @@ -47,34 +47,24 @@ public class SimpleRecordSchema implements RecordSchema { this(fields, createText(fields), null, false, id); } + public SimpleRecordSchema(final String text, final String schemaFormat, final SchemaIdentifier id) { + this(text, schemaFormat, true, id); + } + public SimpleRecordSchema(final List fields, final String text, final String schemaFormat, final SchemaIdentifier id) { this(fields, text, schemaFormat, true, id); } private SimpleRecordSchema(final List fields, final String text, final String schemaFormat, final boolean textAvailable, final SchemaIdentifier id) { + this(text, schemaFormat, textAvailable, id); + setFields(fields); + } + + private SimpleRecordSchema(final String text, final String schemaFormat, final boolean textAvailable, final SchemaIdentifier id) { this.text = text; this.schemaFormat = schemaFormat; this.schemaIdentifier = id; this.textAvailable = textAvailable; - this.fields = Collections.unmodifiableList(new ArrayList<>(fields)); - this.fieldIndices = new HashMap<>(fields.size()); - - int index = 0; - for (final RecordField field : fields) { - Integer previousValue = fieldIndices.put(field.getFieldName(), index); - if (previousValue != null) { - throw new IllegalArgumentException("Two fields are given with the same name (or alias) of '" + field.getFieldName() + "'"); - } - - for (final String alias : field.getAliases()) { - previousValue = fieldIndices.put(alias, index); - if (previousValue != null) { - throw new IllegalArgumentException("Two fields are given with the same name (or alias) of '" + field.getFieldName() + "'"); - } - } - - index++; - } } @Override @@ -97,6 +87,33 @@ public class SimpleRecordSchema implements RecordSchema { return fields; } + public void setFields(final List fields) { + + if (this.fields != null) { + throw new IllegalArgumentException("Fields have already been set."); + } + + this.fields = Collections.unmodifiableList(new ArrayList<>(fields)); + this.fieldIndices = new HashMap<>(fields.size()); + + int index = 0; + for (final RecordField field : fields) { + Integer previousValue = fieldIndices.put(field.getFieldName(), index); + if (previousValue != null) { + throw new IllegalArgumentException("Two fields are given with the same name (or alias) of '" + field.getFieldName() + "'"); + } + + for (final String alias : field.getAliases()) { + previousValue = fieldIndices.put(alias, index); + if (previousValue != null) { + throw new IllegalArgumentException("Two fields are given with the same name (or alias) of '" + field.getFieldName() + "'"); + } + } + + index++; + } + } + @Override public int getFieldCount() { return fields.size(); diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/main/java/org/apache/nifi/avro/AvroTypeUtil.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/main/java/org/apache/nifi/avro/AvroTypeUtil.java index a39a7f4006..4797916c57 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/main/java/org/apache/nifi/avro/AvroTypeUtil.java +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/main/java/org/apache/nifi/avro/AvroTypeUtil.java @@ -218,6 +218,15 @@ public class AvroTypeUtil { * @return a Data Type that corresponds to the given Avro Schema */ public static DataType determineDataType(final Schema avroSchema) { + return determineDataType(avroSchema, new HashMap<>()); + } + + public static DataType determineDataType(final Schema avroSchema, Map knownRecordTypes) { + + if (knownRecordTypes == null) { + throw new IllegalArgumentException("'knownRecordTypes' cannot be null."); + } + final Type avroType = avroSchema.getType(); final LogicalType logicalType = avroSchema.getLogicalType(); @@ -241,7 +250,7 @@ public class AvroTypeUtil { switch (avroType) { case ARRAY: - return RecordFieldType.ARRAY.getArrayDataType(determineDataType(avroSchema.getElementType())); + return RecordFieldType.ARRAY.getArrayDataType(determineDataType(avroSchema.getElementType(), knownRecordTypes)); case BYTES: case FIXED: return RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType()); @@ -259,40 +268,50 @@ public class AvroTypeUtil { case LONG: return RecordFieldType.LONG.getDataType(); case RECORD: { - final List avroFields = avroSchema.getFields(); - final List recordFields = new ArrayList<>(avroFields.size()); + String schemaFullName = avroSchema.getNamespace() + "." + avroSchema.getName(); - for (final Field field : avroFields) { - final String fieldName = field.name(); - final Schema fieldSchema = field.schema(); - final DataType fieldType = determineDataType(fieldSchema); + if (knownRecordTypes.containsKey(schemaFullName)) { + return knownRecordTypes.get(schemaFullName); + } else { + SimpleRecordSchema recordSchema = new SimpleRecordSchema(avroSchema.toString(), AVRO_SCHEMA_FORMAT, SchemaIdentifier.EMPTY); + DataType recordSchemaType = RecordFieldType.RECORD.getRecordDataType(recordSchema); + knownRecordTypes.put(schemaFullName, recordSchemaType); - if (field.defaultVal() == JsonProperties.NULL_VALUE) { - recordFields.add(new RecordField(fieldName, fieldType, field.aliases())); - } else { - recordFields.add(new RecordField(fieldName, fieldType, field.defaultVal(), field.aliases())); + final List avroFields = avroSchema.getFields(); + final List recordFields = new ArrayList<>(avroFields.size()); + + for (final Field field : avroFields) { + final String fieldName = field.name(); + final Schema fieldSchema = field.schema(); + final DataType fieldType = determineDataType(fieldSchema, knownRecordTypes); + + if (field.defaultVal() == JsonProperties.NULL_VALUE) { + recordFields.add(new RecordField(fieldName, fieldType, field.aliases())); + } else { + recordFields.add(new RecordField(fieldName, fieldType, field.defaultVal(), field.aliases())); + } } - } - final RecordSchema recordSchema = new SimpleRecordSchema(recordFields, avroSchema.toString(), AVRO_SCHEMA_FORMAT, SchemaIdentifier.EMPTY); - return RecordFieldType.RECORD.getRecordDataType(recordSchema); + recordSchema.setFields(recordFields); + return recordSchemaType; + } } case NULL: return RecordFieldType.STRING.getDataType(); case MAP: final Schema valueSchema = avroSchema.getValueType(); - final DataType valueType = determineDataType(valueSchema); + final DataType valueType = determineDataType(valueSchema, knownRecordTypes); return RecordFieldType.MAP.getMapDataType(valueType); case UNION: { final List nonNullSubSchemas = getNonNullSubSchemas(avroSchema); if (nonNullSubSchemas.size() == 1) { - return determineDataType(nonNullSubSchemas.get(0)); + return determineDataType(nonNullSubSchemas.get(0), knownRecordTypes); } final List possibleChildTypes = new ArrayList<>(nonNullSubSchemas.size()); for (final Schema subSchema : nonNullSubSchemas) { - final DataType childDataType = determineDataType(subSchema); + final DataType childDataType = determineDataType(subSchema, knownRecordTypes); possibleChildTypes.add(childDataType); } @@ -334,10 +353,16 @@ public class AvroTypeUtil { throw new IllegalArgumentException("Avro Schema cannot be null"); } + String schemaFullName = avroSchema.getNamespace() + "." + avroSchema.getName(); + SimpleRecordSchema recordSchema = new SimpleRecordSchema(avroSchema.toString(), AVRO_SCHEMA_FORMAT, SchemaIdentifier.EMPTY); + DataType recordSchemaType = RecordFieldType.RECORD.getRecordDataType(recordSchema); + Map knownRecords = new HashMap<>(); + knownRecords.put(schemaFullName, recordSchemaType); + final List recordFields = new ArrayList<>(avroSchema.getFields().size()); for (final Field field : avroSchema.getFields()) { final String fieldName = field.name(); - final DataType dataType = AvroTypeUtil.determineDataType(field.schema()); + final DataType dataType = AvroTypeUtil.determineDataType(field.schema(), knownRecords); if (field.defaultVal() == JsonProperties.NULL_VALUE) { recordFields.add(new RecordField(fieldName, dataType, field.aliases())); @@ -346,7 +371,7 @@ public class AvroTypeUtil { } } - final RecordSchema recordSchema = new SimpleRecordSchema(recordFields, schemaText, AVRO_SCHEMA_FORMAT, schemaId); + recordSchema.setFields(recordFields); return recordSchema; } diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/test/java/org/apache/nifi/avro/TestAvroTypeUtil.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/test/java/org/apache/nifi/avro/TestAvroTypeUtil.java index fe1973351f..b0178298a5 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/test/java/org/apache/nifi/avro/TestAvroTypeUtil.java +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-avro-record-utils/src/test/java/org/apache/nifi/avro/TestAvroTypeUtil.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import org.apache.avro.Schema; import org.apache.avro.Schema.Field; @@ -33,6 +34,8 @@ import org.apache.nifi.serialization.record.DataType; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.type.RecordDataType; +import org.junit.Assert; import org.junit.Test; public class TestAvroTypeUtil { @@ -100,4 +103,140 @@ public class TestAvroTypeUtil { assertEquals(Collections.singleton("greeting"), stringField.getAliases()); } + @Test + // Simple recursion is a record A composing itself (similar to a LinkedList Node referencing 'next') + public void testSimpleRecursiveSchema() { + Schema recursiveSchema = new Schema.Parser().parse( + "{\n" + + " \"namespace\": \"org.apache.nifi.testing\",\n" + + " \"name\": \"NodeRecord\",\n" + + " \"type\": \"record\",\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"id\",\n" + + " \"type\": \"int\"\n" + + " },\n" + + " {\n" + + " \"name\": \"value\",\n" + + " \"type\": \"string\"\n" + + " },\n" + + " {\n" + + " \"name\": \"parent\",\n" + + " \"type\": [\n" + + " \"null\",\n" + + " \"NodeRecord\"\n" + + " ]\n" + + " }\n" + + " ]\n" + + "}\n" + ); + + // Make sure the following doesn't throw an exception + RecordSchema result = AvroTypeUtil.createSchema(recursiveSchema); + + // Make sure it parsed correctly + Assert.assertEquals(3, result.getFieldCount()); + + Optional idField = result.getField("id"); + Assert.assertTrue(idField.isPresent()); + Assert.assertEquals(RecordFieldType.INT, idField.get().getDataType().getFieldType()); + + Optional valueField = result.getField("value"); + Assert.assertTrue(valueField.isPresent()); + Assert.assertEquals(RecordFieldType.STRING, valueField.get().getDataType().getFieldType()); + + Optional parentField = result.getField("parent"); + Assert.assertTrue(parentField.isPresent()); + Assert.assertEquals(RecordFieldType.RECORD, parentField.get().getDataType().getFieldType()); + + // The 'parent' field should have a circular schema reference to the top level record schema, similar to how Avro handles this + Assert.assertEquals(result, ((RecordDataType)parentField.get().getDataType()).getChildSchema()); + } + + @Test + // Complicated recursion is a record A composing record B, who composes a record A + public void testComplicatedRecursiveSchema() { + Schema recursiveSchema = new Schema.Parser().parse( + "{\n" + + " \"namespace\": \"org.apache.nifi.testing\",\n" + + " \"name\": \"Record_A\",\n" + + " \"type\": \"record\",\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"id\",\n" + + " \"type\": \"int\"\n" + + " },\n" + + " {\n" + + " \"name\": \"value\",\n" + + " \"type\": \"string\"\n" + + " },\n" + + " {\n" + + " \"name\": \"child\",\n" + + " \"type\": {\n" + + " \"namespace\": \"org.apache.nifi.testing\",\n" + + " \"name\": \"Record_B\",\n" + + " \"type\": \"record\",\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"id\",\n" + + " \"type\": \"int\"\n" + + " },\n" + + " {\n" + + " \"name\": \"value\",\n" + + " \"type\": \"string\"\n" + + " },\n" + + " {\n" + + " \"name\": \"parent\",\n" + + " \"type\": [\n" + + " \"null\",\n" + + " \"Record_A\"\n" + + " ]\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}\n" + ); + + // Make sure the following doesn't throw an exception + RecordSchema recordASchema = AvroTypeUtil.createSchema(recursiveSchema); + + // Make sure it parsed correctly + Assert.assertEquals(3, recordASchema.getFieldCount()); + + Optional recordAIdField = recordASchema.getField("id"); + Assert.assertTrue(recordAIdField.isPresent()); + Assert.assertEquals(RecordFieldType.INT, recordAIdField.get().getDataType().getFieldType()); + + Optional recordAValueField = recordASchema.getField("value"); + Assert.assertTrue(recordAValueField.isPresent()); + Assert.assertEquals(RecordFieldType.STRING, recordAValueField.get().getDataType().getFieldType()); + + Optional recordAChildField = recordASchema.getField("child"); + Assert.assertTrue(recordAChildField.isPresent()); + Assert.assertEquals(RecordFieldType.RECORD, recordAChildField.get().getDataType().getFieldType()); + + // Get the child schema + RecordSchema recordBSchema = ((RecordDataType)recordAChildField.get().getDataType()).getChildSchema(); + + // Make sure it parsed correctly + Assert.assertEquals(3, recordBSchema.getFieldCount()); + + Optional recordBIdField = recordBSchema.getField("id"); + Assert.assertTrue(recordBIdField.isPresent()); + Assert.assertEquals(RecordFieldType.INT, recordBIdField.get().getDataType().getFieldType()); + + Optional recordBValueField = recordBSchema.getField("value"); + Assert.assertTrue(recordBValueField.isPresent()); + Assert.assertEquals(RecordFieldType.STRING, recordBValueField.get().getDataType().getFieldType()); + + Optional recordBParentField = recordBSchema.getField("parent"); + Assert.assertTrue(recordBParentField.isPresent()); + Assert.assertEquals(RecordFieldType.RECORD, recordBParentField.get().getDataType().getFieldType()); + + // Make sure the 'parent' field has a schema reference back to the original top level record schema + Assert.assertEquals(recordASchema, ((RecordDataType)recordBParentField.get().getDataType()).getChildSchema()); + } + }