From ad95287e782d71a778dfa89b3cb2a17637b32bf5 Mon Sep 17 00:00:00 2001 From: Tamas Palfy Date: Thu, 18 Jun 2020 19:42:35 +0200 Subject: [PATCH] NIFI-6934 In PutDatabaseRecord added DatabaseAdapter-based UPSERT support for Postgres (9.5+) NIFI-6934 Added more documentation and unit tests. NIFI-6934 Added missing license for new test class. Signed-off-by: Matthew Burgess This closes #4350 --- .../standard/PutDatabaseRecord.java | 214 ++++++++++++++---- .../standard/db/DatabaseAdapter.java | 27 +++ .../db/impl/PostgreSQLDatabaseAdapter.java | 75 ++++++ ...ifi.processors.standard.db.DatabaseAdapter | 3 +- .../impl/TestPostgreSQLDatabaseAdapter.java | 108 +++++++++ 5 files changed, 377 insertions(+), 50 deletions(-) create mode 100644 nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java create mode 100644 nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java index 926c5cd7bc..d6384d24f0 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutDatabaseRecord.java @@ -29,6 +29,8 @@ import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.lifecycle.OnScheduled; import org.apache.nifi.components.AllowableValue; import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.components.ValidationContext; +import org.apache.nifi.components.ValidationResult; import org.apache.nifi.dbcp.DBCPService; import org.apache.nifi.expression.AttributeExpression; import org.apache.nifi.expression.ExpressionLanguageScope; @@ -47,6 +49,7 @@ import org.apache.nifi.processor.util.pattern.PartialFunctions; import org.apache.nifi.processor.util.pattern.Put; import org.apache.nifi.processor.util.pattern.RollbackOnFailure; import org.apache.nifi.processor.util.pattern.RoutingResult; +import org.apache.nifi.processors.standard.db.DatabaseAdapter; import org.apache.nifi.serialization.MalformedRecordException; import org.apache.nifi.serialization.RecordReader; import org.apache.nifi.serialization.RecordReaderFactory; @@ -70,11 +73,13 @@ import java.sql.SQLIntegrityConstraintViolationException; import java.sql.SQLNonTransientException; import java.sql.Statement; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -101,6 +106,7 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { static final String UPDATE_TYPE = "UPDATE"; static final String INSERT_TYPE = "INSERT"; static final String DELETE_TYPE = "DELETE"; + static final String UPSERT_TYPE = "UPSERT"; static final String SQL_TYPE = "SQL"; // Not an allowable value in the Statement Type property, must be set by attribute static final String USE_ATTR_TYPE = "Use statement.type Attribute"; @@ -152,11 +158,14 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { static final PropertyDescriptor STATEMENT_TYPE = new PropertyDescriptor.Builder() .name("put-db-record-statement-type") .displayName("Statement Type") - .description("Specifies the type of SQL Statement to generate. If 'Use statement.type Attribute' is chosen, then the value is taken from the statement.type attribute in the " + .description("Specifies the type of SQL Statement to generate. " + + "Please refer to the database documentation for a description of the behavior of each operation. " + + "Please note that some Database Types may not support certain Statement Types. " + + "If 'Use statement.type Attribute' is chosen, then the value is taken from the statement.type attribute in the " + "FlowFile. The 'Use statement.type Attribute' option is the only one that allows the 'SQL' statement type. If 'SQL' is specified, the value of the field specified by the " + "'Field Containing SQL' property is expected to be a valid SQL statement on the target database, and will be executed as-is.") .required(true) - .allowableValues(UPDATE_TYPE, INSERT_TYPE, DELETE_TYPE, USE_ATTR_TYPE) + .allowableValues(UPDATE_TYPE, INSERT_TYPE, UPSERT_TYPE, DELETE_TYPE, USE_ATTR_TYPE) .build(); static final PropertyDescriptor DBCP_SERVICE = new PropertyDescriptor.Builder() @@ -299,11 +308,34 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) .build(); + static final PropertyDescriptor DB_TYPE; + + protected static final Map dbAdapters; + protected static List propDescriptors; private Cache schemaCache; static { + dbAdapters = new HashMap<>(); + ArrayList dbAdapterValues = new ArrayList<>(); + + ServiceLoader dbAdapterLoader = ServiceLoader.load(DatabaseAdapter.class); + dbAdapterLoader.forEach(databaseAdapter -> { + dbAdapters.put(databaseAdapter.getName(), databaseAdapter); + dbAdapterValues.add(new AllowableValue(databaseAdapter.getName(), databaseAdapter.getName(), databaseAdapter.getDescription())); + }); + + DB_TYPE = new PropertyDescriptor.Builder() + .name("db-type") + .displayName("Database Type") + .description("The type/flavor of database, used for generating database-specific code. In many cases the Generic type " + + "should suffice, but some databases (such as Oracle) require custom SQL clauses. ") + .allowableValues(dbAdapterValues.toArray(new AllowableValue[dbAdapterValues.size()])) + .defaultValue("Generic") + .required(false) + .build(); + final Set r = new HashSet<>(); r.add(REL_SUCCESS); r.add(REL_FAILURE); @@ -312,6 +344,7 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { final List pds = new ArrayList<>(); pds.add(RECORD_READER_FACTORY); + pds.add(DB_TYPE); pds.add(STATEMENT_TYPE); pds.add(DBCP_SERVICE); pds.add(CATALOG_NAME); @@ -335,6 +368,7 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { private Put process; private ExceptionHandler exceptionHandler; + private DatabaseAdapter databaseAdapter; @Override public Set getRelationships() { @@ -444,9 +478,29 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { }); }; + @Override + protected Collection customValidate(ValidationContext validationContext) { + Collection validationResults = new ArrayList<>(super.customValidate(validationContext)); + + DatabaseAdapter databaseAdapter = dbAdapters.get(validationContext.getProperty(DB_TYPE).getValue()); + String statementType = validationContext.getProperty(STATEMENT_TYPE).getValue(); + + if (UPSERT_TYPE.equals(statementType) && !databaseAdapter.supportsUpsert()) { + validationResults.add(new ValidationResult.Builder() + .subject(STATEMENT_TYPE.getDisplayName()) + .valid(false) + .explanation(databaseAdapter.getName() + " does not support " + statementType) + .build() + ); + } + + return validationResults; + } @OnScheduled public void onScheduled(final ProcessContext context) { + databaseAdapter = dbAdapters.get(context.getProperty(DB_TYPE).getValue()); + final int tableSchemaCacheSize = context.getProperty(TABLE_SCHEMA_CACHE_SIZE).asInteger(); schemaCache = Caffeine.newBuilder() .maximumSize(tableSchemaCacheSize) @@ -657,6 +711,9 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { } else if (DELETE_TYPE.equalsIgnoreCase(statementType)) { sqlHolder = generateDelete(recordSchema, fqTableName, tableSchema, settings); + } else if (UPSERT_TYPE.equalsIgnoreCase(statementType)) { + sqlHolder = generateUpsert(recordSchema, fqTableName, updateKeys, tableSchema, settings); + } else { throw new IllegalArgumentException(format("Statement Type %s is not valid, FlowFile %s", statementType, flowFile)); } @@ -790,20 +847,7 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { SqlAndIncludedColumns generateInsert(final RecordSchema recordSchema, final String tableName, final TableSchema tableSchema, final DMLSettings settings) throws IllegalArgumentException, SQLException { - final Set normalizedFieldNames = getNormalizedColumnNames(recordSchema, settings.translateFieldNames); - - for (final String requiredColName : tableSchema.getRequiredColumnNames()) { - final String normalizedColName = normalizeColumnName(requiredColName, settings.translateFieldNames); - if (!normalizedFieldNames.contains(normalizedColName)) { - String missingColMessage = "Record does not have a value for the Required column '" + requiredColName + "'"; - if (settings.failUnmappedColumns) { - getLogger().error(missingColMessage); - throw new IllegalArgumentException(missingColMessage); - } else if (settings.warningUnmappedColumns) { - getLogger().warn(missingColMessage); - } - } - } + checkValuesForRequiredColumns(recordSchema, tableSchema, settings); final StringBuilder sqlBuilder = new StringBuilder(); sqlBuilder.append("INSERT INTO "); @@ -854,47 +898,59 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { return new SqlAndIncludedColumns(sqlBuilder.toString(), includedColumns); } + SqlAndIncludedColumns generateUpsert(final RecordSchema recordSchema, final String tableName, final String updateKeys, + final TableSchema tableSchema, final DMLSettings settings) + throws IllegalArgumentException, SQLException, MalformedRecordException { + + checkValuesForRequiredColumns(recordSchema, tableSchema, settings); + + Set keyColumnNames = getUpdateKeyColumnNames(tableName, updateKeys, tableSchema); + Set normalizedKeyColumnNames = normalizeKeyColumnNamesAndCheckForValues(recordSchema, updateKeys, settings, keyColumnNames); + + List usedColumnNames = new ArrayList<>(); + List usedColumnIndices = new ArrayList<>(); + + List fieldNames = recordSchema.getFieldNames(); + if (fieldNames != null) { + int fieldCount = fieldNames.size(); + + for (int i = 0; i < fieldCount; i++) { + RecordField field = recordSchema.getField(i); + String fieldName = field.getFieldName(); + + final ColumnDescription desc = tableSchema.getColumns().get(normalizeColumnName(fieldName, settings.translateFieldNames)); + if (desc == null && !settings.ignoreUnmappedFields) { + throw new SQLDataException("Cannot map field '" + fieldName + "' to any column in the database"); + } + + if (desc != null) { + if (settings.escapeColumnNames) { + usedColumnNames.add(tableSchema.getQuotedIdentifierString() + desc.getColumnName() + tableSchema.getQuotedIdentifierString()); + } else { + usedColumnNames.add(desc.getColumnName()); + } + usedColumnIndices.add(i); + } + } + } + + String sql = databaseAdapter.getUpsertStatement(tableName, usedColumnNames, normalizedKeyColumnNames); + + return new SqlAndIncludedColumns(sql, usedColumnIndices); + } + SqlAndIncludedColumns generateUpdate(final RecordSchema recordSchema, final String tableName, final String updateKeys, final TableSchema tableSchema, final DMLSettings settings) throws IllegalArgumentException, MalformedRecordException, SQLException { - final Set updateKeyNames; - if (updateKeys == null) { - updateKeyNames = tableSchema.getPrimaryKeyColumnNames(); - } else { - updateKeyNames = new HashSet<>(); - for (final String updateKey : updateKeys.split(",")) { - updateKeyNames.add(updateKey.trim()); - } - } - if (updateKeyNames.isEmpty()) { - throw new SQLIntegrityConstraintViolationException("Table '" + tableName + "' does not have a Primary Key and no Update Keys were specified"); - } + final Set keyColumnNames = getUpdateKeyColumnNames(tableName, updateKeys, tableSchema); + final Set normalizedKeyColumnNames = normalizeKeyColumnNamesAndCheckForValues(recordSchema, updateKeys, settings, keyColumnNames); final StringBuilder sqlBuilder = new StringBuilder(); sqlBuilder.append("UPDATE "); sqlBuilder.append(tableName); - // Create a Set of all normalized Update Key names, and ensure that there is a field in the record - // for each of the Update Key fields. - final Set normalizedFieldNames = getNormalizedColumnNames(recordSchema, settings.translateFieldNames); - final Set normalizedUpdateNames = new HashSet<>(); - for (final String uk : updateKeyNames) { - final String normalizedUK = normalizeColumnName(uk, settings.translateFieldNames); - normalizedUpdateNames.add(normalizedUK); - - if (!normalizedFieldNames.contains(normalizedUK)) { - String missingColMessage = "Record does not have a value for the " + (updateKeys == null ? "Primary" : "Update") + "Key column '" + uk + "'"; - if (settings.failUnmappedColumns) { - getLogger().error(missingColMessage); - throw new MalformedRecordException(missingColMessage); - } else if (settings.warningUnmappedColumns) { - getLogger().warn(missingColMessage); - } - } - } - // iterate over all of the fields in the record, building the SQL statement by adding the column names List fieldNames = recordSchema.getFieldNames(); final List includedColumns = new ArrayList<>(); @@ -920,7 +976,7 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { // Check if this column is an Update Key. If so, skip it for now. We will come // back to it after we finish the SET clause - if (!normalizedUpdateNames.contains(normalizedColName)) { + if (!normalizedKeyColumnNames.contains(normalizedColName)) { if (fieldsFound.getAndIncrement() > 0) { sqlBuilder.append(", "); } @@ -952,7 +1008,7 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { if (desc != null) { // Check if this column is a Update Key. If so, add it to the WHERE clause - if (normalizedUpdateNames.contains(normalizedColName)) { + if (normalizedKeyColumnNames.contains(normalizedColName)) { if (whereFieldCount.getAndIncrement() > 0) { sqlBuilder.append(" AND "); @@ -1045,6 +1101,66 @@ public class PutDatabaseRecord extends AbstractSessionFactoryProcessor { return new SqlAndIncludedColumns(sqlBuilder.toString(), includedColumns); } + private void checkValuesForRequiredColumns(RecordSchema recordSchema, TableSchema tableSchema, DMLSettings settings) { + final Set normalizedFieldNames = getNormalizedColumnNames(recordSchema, settings.translateFieldNames); + + for (final String requiredColName : tableSchema.getRequiredColumnNames()) { + final String normalizedColName = normalizeColumnName(requiredColName, settings.translateFieldNames); + if (!normalizedFieldNames.contains(normalizedColName)) { + String missingColMessage = "Record does not have a value for the Required column '" + requiredColName + "'"; + if (settings.failUnmappedColumns) { + getLogger().error(missingColMessage); + throw new IllegalArgumentException(missingColMessage); + } else if (settings.warningUnmappedColumns) { + getLogger().warn(missingColMessage); + } + } + } + } + + private Set getUpdateKeyColumnNames(String tableName, String updateKeys, TableSchema tableSchema) throws SQLIntegrityConstraintViolationException { + final Set updateKeyColumnNames; + + if (updateKeys == null) { + updateKeyColumnNames = tableSchema.getPrimaryKeyColumnNames(); + } else { + updateKeyColumnNames = new HashSet<>(); + for (final String updateKey : updateKeys.split(",")) { + updateKeyColumnNames.add(updateKey.trim()); + } + } + + if (updateKeyColumnNames.isEmpty()) { + throw new SQLIntegrityConstraintViolationException("Table '" + tableName + "' does not have a Primary Key and no Update Keys were specified"); + } + + return updateKeyColumnNames; + } + + private Set normalizeKeyColumnNamesAndCheckForValues(RecordSchema recordSchema, String updateKeys, DMLSettings settings, Set updateKeyColumnNames) throws MalformedRecordException { + // Create a Set of all normalized Update Key names, and ensure that there is a field in the record + // for each of the Update Key fields. + final Set normalizedRecordFieldNames = getNormalizedColumnNames(recordSchema, settings.translateFieldNames); + + final Set normalizedKeyColumnNames = new HashSet<>(); + for (final String updateKeyColumnName : updateKeyColumnNames) { + final String normalizedKeyColumnName = normalizeColumnName(updateKeyColumnName, settings.translateFieldNames); + normalizedKeyColumnNames.add(normalizedKeyColumnName); + + if (!normalizedRecordFieldNames.contains(normalizedKeyColumnName)) { + String missingColMessage = "Record does not have a value for the " + (updateKeys == null ? "Primary" : "Update") + "Key column '" + updateKeyColumnName + "'"; + if (settings.failUnmappedColumns) { + getLogger().error(missingColMessage); + throw new MalformedRecordException(missingColMessage); + } else if (settings.warningUnmappedColumns) { + getLogger().warn(missingColMessage); + } + } + } + + return normalizedKeyColumnNames; + } + private static String normalizeColumnName(final String colName, final boolean translateColumnNames) { return colName == null ? null : (translateColumnNames ? colName.toUpperCase().replace("_", "") : colName); } diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java index e1251c4767..40de0b825f 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/DatabaseAdapter.java @@ -16,6 +16,9 @@ */ package org.apache.nifi.processors.standard.db; +import java.util.Collection; +import java.util.List; + /** * Interface for RDBMS/JDBC-specific code. */ @@ -55,6 +58,30 @@ public interface DatabaseAdapter { return getSelectStatement(tableName, columnNames, whereClause, orderByClause, limit, offset); } + /** + * Tells whether this adapter supports UPSERT. + * + * @return true if UPSERT is supported, false otherwise + */ + default boolean supportsUpsert() { + return false; + } + + /** + * Returns an SQL UPSERT statement - i.e. UPDATE record or INSERT if id doesn't exist. + *

+ * There is no standard way of doing this so not all adapters support it - use together with {@link #supportsUpsert()}! + * + * @param table The name of the table in which to update/insert a record into. + * @param columnNames The name of the columns in the table to add values to. + * @param uniqueKeyColumnNames The name of the columns that form a unique key. + * @return A String containing the parameterized jdbc SQL statement. + * The order and number of parameters are the same as that of the provided column list. + */ + default String getUpsertStatement(String table, List columnNames, Collection uniqueKeyColumnNames) { + throw new UnsupportedOperationException("UPSERT is not supported for " + getName()); + } + /** *

Returns a bare identifier string by removing wrapping escape characters * from identifier strings such as table and column names.

diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java new file mode 100644 index 0000000000..03def343fd --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/db/impl/PostgreSQLDatabaseAdapter.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard.db.impl; + +import com.google.common.base.Preconditions; +import org.apache.nifi.util.StringUtils; + +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +public class PostgreSQLDatabaseAdapter extends GenericDatabaseAdapter { + @Override + public String getName() { + return "PostgreSQL"; + } + + @Override + public String getDescription() { + return "Generates PostgreSQL compatible SQL"; + } + + @Override + public boolean supportsUpsert() { + return true; + } + + @Override + public String getUpsertStatement(String table, List columnNames, Collection uniqueKeyColumnNames) { + Preconditions.checkArgument(!StringUtils.isEmpty(table), "Table name cannot be null or blank"); + Preconditions.checkArgument(columnNames != null && !columnNames.isEmpty(), "Column names cannot be null or empty"); + Preconditions.checkArgument(uniqueKeyColumnNames != null && !uniqueKeyColumnNames.isEmpty(), "Key column names cannot be null or empty"); + + String columns = columnNames.stream() + .collect(Collectors.joining(", ")); + + String parameterizedInsertValues = columnNames.stream() + .map(__ -> "?") + .collect(Collectors.joining(", ")); + + String updateValues = columnNames.stream() + .map(columnName -> "EXCLUDED." + columnName) + .collect(Collectors.joining(", ")); + + String conflictClause = "(" + uniqueKeyColumnNames.stream().collect(Collectors.joining(", ")) + ")"; + + StringBuilder statementStringBuilder = new StringBuilder("INSERT INTO ") + .append(table) + .append("(").append(columns).append(")") + .append(" VALUES ") + .append("(").append(parameterizedInsertValues).append(")") + .append(" ON CONFLICT ") + .append(conflictClause) + .append(" DO UPDATE SET ") + .append("(").append(columns).append(")") + .append(" = ") + .append("(").append(updateValues).append(")"); + + return statementStringBuilder.toString(); + } +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter index 2f53cf7f87..f104782c5b 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processors.standard.db.DatabaseAdapter @@ -17,4 +17,5 @@ org.apache.nifi.processors.standard.db.impl.OracleDatabaseAdapter org.apache.nifi.processors.standard.db.impl.Oracle12DatabaseAdapter org.apache.nifi.processors.standard.db.impl.MSSQLDatabaseAdapter org.apache.nifi.processors.standard.db.impl.MSSQL2008DatabaseAdapter -org.apache.nifi.processors.standard.db.impl.MySQLDatabaseAdapter \ No newline at end of file +org.apache.nifi.processors.standard.db.impl.MySQLDatabaseAdapter +org.apache.nifi.processors.standard.db.impl.PostgreSQLDatabaseAdapter \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java new file mode 100644 index 0000000000..15fc95a6ca --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/db/impl/TestPostgreSQLDatabaseAdapter.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard.db.impl; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class TestPostgreSQLDatabaseAdapter { + private PostgreSQLDatabaseAdapter testSubject; + + @Before + public void setUp() throws Exception { + testSubject = new PostgreSQLDatabaseAdapter(); + } + + @Test + public void testSupportsUpsert() throws Exception { + assertTrue(testSubject.getClass().getSimpleName() + " should support upsert", testSubject.supportsUpsert()); + } + + @Test + public void testGetUpsertStatementWithNullTableName() throws Exception { + testGetUpsertStatement(null, Arrays.asList("notEmpty"), Arrays.asList("notEmpty"), new IllegalArgumentException("Table name cannot be null or blank")); + } + + @Test + public void testGetUpsertStatementWithBlankTableName() throws Exception { + testGetUpsertStatement("", Arrays.asList("notEmpty"), Arrays.asList("notEmpty"), new IllegalArgumentException("Table name cannot be null or blank")); + } + + @Test + public void testGetUpsertStatementWithNullColumnNames() throws Exception { + testGetUpsertStatement("notEmpty", null, Arrays.asList("notEmpty"), new IllegalArgumentException("Column names cannot be null or empty")); + } + + @Test + public void testGetUpsertStatementWithEmptyColumnNames() throws Exception { + testGetUpsertStatement("notEmpty", Collections.emptyList(), Arrays.asList("notEmpty"), new IllegalArgumentException("Column names cannot be null or empty")); + } + + @Test + public void testGetUpsertStatementWithNullKeyColumnNames() throws Exception { + testGetUpsertStatement("notEmpty", Arrays.asList("notEmpty"), null, new IllegalArgumentException("Key column names cannot be null or empty")); + } + + @Test + public void testGetUpsertStatementWithEmptyKeyColumnNames() throws Exception { + testGetUpsertStatement("notEmpty", Arrays.asList("notEmpty"), Collections.emptyList(), new IllegalArgumentException("Key column names cannot be null or empty")); + } + + @Test + public void testGetUpsertStatement() throws Exception { + // GIVEN + String tableName = "table"; + List columnNames = Arrays.asList("column1","column2", "column3", "column4"); + Collection uniqueKeyColumnNames = Arrays.asList("column2","column4"); + + String expected = "INSERT INTO" + + " table(column1, column2, column3, column4) VALUES (?, ?, ?, ?)" + + " ON CONFLICT (column2, column4)" + + " DO UPDATE SET" + + " (column1, column2, column3, column4) = (EXCLUDED.column1, EXCLUDED.column2, EXCLUDED.column3, EXCLUDED.column4)"; + + // WHEN + // THEN + testGetUpsertStatement(tableName, columnNames, uniqueKeyColumnNames, expected); + } + + private void testGetUpsertStatement(String tableName, List columnNames, Collection uniqueKeyColumnNames, IllegalArgumentException expected) { + try { + testGetUpsertStatement(tableName, columnNames, uniqueKeyColumnNames, (String)null); + fail(); + } catch (IllegalArgumentException e) { + assertEquals(expected.getMessage(), e.getMessage()); + } + } + + private void testGetUpsertStatement(String tableName, List columnNames, Collection uniqueKeyColumnNames, String expected) { + // WHEN + String actual = testSubject.getUpsertStatement(tableName, columnNames, uniqueKeyColumnNames); + + // THEN + assertEquals(expected, actual); + } +}