diff --git a/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/dbcp/hive/HiveConnectionPool.java b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/dbcp/hive/HiveConnectionPool.java index 5d806102b0..64f3027749 100644 --- a/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/dbcp/hive/HiveConnectionPool.java +++ b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/dbcp/hive/HiveConnectionPool.java @@ -44,6 +44,7 @@ import org.apache.nifi.util.hive.HiveUtils; import org.apache.nifi.util.hive.ValidationResources; import java.io.IOException; +import java.lang.reflect.UndeclaredThrowableException; import java.security.PrivilegedExceptionAction; import java.sql.Connection; import java.sql.SQLException; @@ -278,13 +279,16 @@ public class HiveConnectionPool extends AbstractControllerService implements Hiv public Connection getConnection() throws ProcessException { try { if (ugi != null) { - return ugi.doAs(new PrivilegedExceptionAction() { - @Override - public Connection run() throws Exception { - return dataSource.getConnection(); + try { + return ugi.doAs((PrivilegedExceptionAction) () -> dataSource.getConnection()); + } catch (UndeclaredThrowableException e) { + Throwable cause = e.getCause(); + if (cause instanceof SQLException) { + throw (SQLException) cause; + } else { + throw e; } - }); - + } } else { getLogger().info("Simple Authentication"); return dataSource.getConnection(); diff --git a/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/util/hive/HiveWriter.java b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/util/hive/HiveWriter.java index 1cf77a8564..0c121e050a 100644 --- a/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/util/hive/HiveWriter.java +++ b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/main/java/org/apache/nifi/util/hive/HiveWriter.java @@ -19,6 +19,7 @@ package org.apache.nifi.util.hive; import java.io.IOException; +import java.lang.reflect.UndeclaredThrowableException; import java.security.PrivilegedExceptionAction; import java.util.List; import java.util.concurrent.ExecutionException; @@ -83,7 +84,16 @@ public class HiveWriter { if (ugi == null) { return new StrictJsonWriter(endPoint, hiveConf); } else { - return ugi.doAs((PrivilegedExceptionAction) () -> new StrictJsonWriter(endPoint, hiveConf)); + try { + return ugi.doAs((PrivilegedExceptionAction) () -> new StrictJsonWriter(endPoint, hiveConf)); + } catch (UndeclaredThrowableException e) { + Throwable cause = e.getCause(); + if (cause instanceof StreamingException) { + throw (StreamingException) cause; + } else { + throw e; + } + } } } @@ -354,7 +364,16 @@ public class HiveWriter { if (ugi == null) { return callRunner.call(); } - return ugi.doAs((PrivilegedExceptionAction) () -> callRunner.call()); + try { + return ugi.doAs((PrivilegedExceptionAction) () -> callRunner.call()); + } catch (UndeclaredThrowableException e) { + Throwable cause = e.getCause(); + // Unwrap exception so it is thrown the same way as without ugi + if (!(cause instanceof Exception)) { + throw e; + } + throw (Exception)cause; + } }); try { if (callTimeout > 0) { diff --git a/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/test/java/org/apache/nifi/dbcp/hive/HiveConnectionPoolTest.java b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/test/java/org/apache/nifi/dbcp/hive/HiveConnectionPoolTest.java new file mode 100644 index 0000000000..0b5cd8f64a --- /dev/null +++ b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/test/java/org/apache/nifi/dbcp/hive/HiveConnectionPoolTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.dbcp.hive; + +import org.apache.commons.dbcp.BasicDataSource; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.nifi.controller.AbstractControllerService; +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.UndeclaredThrowableException; +import java.security.PrivilegedExceptionAction; +import java.sql.SQLException; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HiveConnectionPoolTest { + private UserGroupInformation userGroupInformation; + private HiveConnectionPool hiveConnectionPool; + private BasicDataSource basicDataSource; + private ComponentLog componentLog; + + @Before + public void setup() throws Exception { + userGroupInformation = mock(UserGroupInformation.class); + basicDataSource = mock(BasicDataSource.class); + componentLog = mock(ComponentLog.class); + + when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> { + try { + return ((PrivilegedExceptionAction) invocation.getArguments()[0]).run(); + } catch (IOException |Error|RuntimeException|InterruptedException e) { + throw e; + } catch (Throwable e) { + throw new UndeclaredThrowableException(e); + } + }); + initPool(); + } + + private void initPool() throws Exception { + hiveConnectionPool = new HiveConnectionPool(); + + Field ugiField = HiveConnectionPool.class.getDeclaredField("ugi"); + ugiField.setAccessible(true); + ugiField.set(hiveConnectionPool, userGroupInformation); + + Field dataSourceField = HiveConnectionPool.class.getDeclaredField("dataSource"); + dataSourceField.setAccessible(true); + dataSourceField.set(hiveConnectionPool, basicDataSource); + + Field componentLogField = AbstractControllerService.class.getDeclaredField("logger"); + componentLogField.setAccessible(true); + componentLogField.set(hiveConnectionPool, componentLog); + } + + @Test(expected = ProcessException.class) + public void testGetConnectionSqlException() throws SQLException { + SQLException sqlException = new SQLException("bad sql"); + when(basicDataSource.getConnection()).thenThrow(sqlException); + try { + hiveConnectionPool.getConnection(); + } catch (ProcessException e) { + assertEquals(sqlException, e.getCause()); + throw e; + } + } +} diff --git a/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/test/java/org/apache/nifi/util/hive/HiveWriterTest.java b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/test/java/org/apache/nifi/util/hive/HiveWriterTest.java new file mode 100644 index 0000000000..e35f487592 --- /dev/null +++ b/nifi-nar-bundles/nifi-hive-bundle/nifi-hive-processors/src/test/java/org/apache/nifi/util/hive/HiveWriterTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.util.hive; + +import com.google.common.util.concurrent.UncheckedExecutionException; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hive.hcatalog.streaming.HiveEndPoint; +import org.apache.hive.hcatalog.streaming.InvalidTable; +import org.apache.hive.hcatalog.streaming.RecordWriter; +import org.apache.hive.hcatalog.streaming.StreamingConnection; +import org.apache.hive.hcatalog.streaming.StreamingException; +import org.apache.hive.hcatalog.streaming.TransactionBatch; +import org.junit.Before; +import org.junit.Test; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.lang.reflect.UndeclaredThrowableException; +import java.security.PrivilegedExceptionAction; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HiveWriterTest { + private HiveEndPoint hiveEndPoint; + private int txnsPerBatch; + private boolean autoCreatePartitions; + private int callTimeout; + private ExecutorService executorService; + private UserGroupInformation userGroupInformation; + private HiveConf hiveConf; + private HiveWriter hiveWriter; + private StreamingConnection streamingConnection; + private RecordWriter recordWriter; + private Callable recordWriterCallable; + private TransactionBatch transactionBatch; + + @Before + public void setup() throws Exception { + hiveEndPoint = mock(HiveEndPoint.class); + txnsPerBatch = 100; + autoCreatePartitions = true; + callTimeout = 0; + executorService = mock(ExecutorService.class); + streamingConnection = mock(StreamingConnection.class); + transactionBatch = mock(TransactionBatch.class); + userGroupInformation = mock(UserGroupInformation.class); + hiveConf = mock(HiveConf.class); + recordWriter = mock(RecordWriter.class); + recordWriterCallable = mock(Callable.class); + when(recordWriterCallable.call()).thenReturn(recordWriter); + + when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenReturn(streamingConnection); + when(streamingConnection.fetchTransactionBatch(txnsPerBatch, recordWriter)).thenReturn(transactionBatch); + when(executorService.submit(isA(Callable.class))).thenAnswer(invocation -> { + Future future = mock(Future.class); + Answer answer = i -> ((Callable) invocation.getArguments()[0]).call(); + when(future.get()).thenAnswer(answer); + when(future.get(anyLong(), any(TimeUnit.class))).thenAnswer(answer); + return future; + }); + when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> { + try { + try { + return ((PrivilegedExceptionAction) invocation.getArguments()[0]).run(); + } catch (UncheckedExecutionException e) { + // Creation of strict json writer will fail due to external deps, this gives us chance to catch it + for (StackTraceElement stackTraceElement : e.getStackTrace()) { + if (stackTraceElement.toString().startsWith("org.apache.hive.hcatalog.streaming.StrictJsonWriter.(")) { + return recordWriterCallable.call(); + } + } + throw e; + } + } catch (IOException | Error | RuntimeException | InterruptedException e) { + throw e; + } catch (Throwable e) { + throw new UndeclaredThrowableException(e); + } + }); + + initWriter(); + } + + private void initWriter() throws Exception { + hiveWriter = new HiveWriter(hiveEndPoint, txnsPerBatch, autoCreatePartitions, callTimeout, executorService, userGroupInformation, hiveConf); + } + + @Test + public void testNormal() { + assertNotNull(hiveWriter); + } + + @Test(expected = HiveWriter.ConnectFailure.class) + public void testNewConnectionInvalidTable() throws Exception { + hiveEndPoint = mock(HiveEndPoint.class); + InvalidTable invalidTable = new InvalidTable("badDb", "badTable"); + when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenThrow(invalidTable); + try { + initWriter(); + } catch (HiveWriter.ConnectFailure e) { + assertEquals(invalidTable, e.getCause()); + throw e; + } + } + + @Test(expected = HiveWriter.ConnectFailure.class) + public void testRecordWriterStreamingException() throws Exception { + recordWriterCallable = mock(Callable.class); + StreamingException streamingException = new StreamingException("Test Exception"); + when(recordWriterCallable.call()).thenThrow(streamingException); + try { + initWriter(); + } catch (HiveWriter.ConnectFailure e) { + assertEquals(streamingException, e.getCause()); + throw e; + } + } +}