NIFI-3755 - Restoring Hive exception handling behavior

This closes #1711.

Signed-off-by: Bryan Bende <bbende@apache.org>
This commit is contained in:
Bryan Rosander 2017-04-27 15:37:27 -04:00 committed by Bryan Bende
parent aa4efb43ca
commit 0054a9e35f
No known key found for this signature in database
GPG Key ID: A0DDA9ED50711C39
4 changed files with 265 additions and 8 deletions

View File

@ -44,6 +44,7 @@ import org.apache.nifi.util.hive.HiveUtils;
import org.apache.nifi.util.hive.ValidationResources; import org.apache.nifi.util.hive.ValidationResources;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
@ -278,13 +279,16 @@ public class HiveConnectionPool extends AbstractControllerService implements Hiv
public Connection getConnection() throws ProcessException { public Connection getConnection() throws ProcessException {
try { try {
if (ugi != null) { if (ugi != null) {
return ugi.doAs(new PrivilegedExceptionAction<Connection>() { try {
@Override return ugi.doAs((PrivilegedExceptionAction<Connection>) () -> dataSource.getConnection());
public Connection run() throws Exception { } catch (UndeclaredThrowableException e) {
return dataSource.getConnection(); Throwable cause = e.getCause();
if (cause instanceof SQLException) {
throw (SQLException) cause;
} else {
throw e;
}
} }
});
} else { } else {
getLogger().info("Simple Authentication"); getLogger().info("Simple Authentication");
return dataSource.getConnection(); return dataSource.getConnection();

View File

@ -19,6 +19,7 @@
package org.apache.nifi.util.hive; package org.apache.nifi.util.hive;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -83,7 +84,16 @@ public class HiveWriter {
if (ugi == null) { if (ugi == null) {
return new StrictJsonWriter(endPoint, hiveConf); return new StrictJsonWriter(endPoint, hiveConf);
} else { } else {
try {
return ugi.doAs((PrivilegedExceptionAction<StrictJsonWriter>) () -> new StrictJsonWriter(endPoint, hiveConf)); return ugi.doAs((PrivilegedExceptionAction<StrictJsonWriter>) () -> new StrictJsonWriter(endPoint, hiveConf));
} catch (UndeclaredThrowableException e) {
Throwable cause = e.getCause();
if (cause instanceof StreamingException) {
throw (StreamingException) cause;
} else {
throw e;
}
}
} }
} }
@ -354,7 +364,16 @@ public class HiveWriter {
if (ugi == null) { if (ugi == null) {
return callRunner.call(); return callRunner.call();
} }
try {
return ugi.doAs((PrivilegedExceptionAction<T>) () -> callRunner.call()); return ugi.doAs((PrivilegedExceptionAction<T>) () -> callRunner.call());
} catch (UndeclaredThrowableException e) {
Throwable cause = e.getCause();
// Unwrap exception so it is thrown the same way as without ugi
if (!(cause instanceof Exception)) {
throw e;
}
throw (Exception)cause;
}
}); });
try { try {
if (callTimeout > 0) { if (callTimeout > 0) {

View File

@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.dbcp.hive;
import org.apache.commons.dbcp.BasicDataSource;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.nifi.controller.AbstractControllerService;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.exception.ProcessException;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction;
import java.sql.SQLException;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class HiveConnectionPoolTest {
private UserGroupInformation userGroupInformation;
private HiveConnectionPool hiveConnectionPool;
private BasicDataSource basicDataSource;
private ComponentLog componentLog;
@Before
public void setup() throws Exception {
userGroupInformation = mock(UserGroupInformation.class);
basicDataSource = mock(BasicDataSource.class);
componentLog = mock(ComponentLog.class);
when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> {
try {
return ((PrivilegedExceptionAction) invocation.getArguments()[0]).run();
} catch (IOException |Error|RuntimeException|InterruptedException e) {
throw e;
} catch (Throwable e) {
throw new UndeclaredThrowableException(e);
}
});
initPool();
}
private void initPool() throws Exception {
hiveConnectionPool = new HiveConnectionPool();
Field ugiField = HiveConnectionPool.class.getDeclaredField("ugi");
ugiField.setAccessible(true);
ugiField.set(hiveConnectionPool, userGroupInformation);
Field dataSourceField = HiveConnectionPool.class.getDeclaredField("dataSource");
dataSourceField.setAccessible(true);
dataSourceField.set(hiveConnectionPool, basicDataSource);
Field componentLogField = AbstractControllerService.class.getDeclaredField("logger");
componentLogField.setAccessible(true);
componentLogField.set(hiveConnectionPool, componentLog);
}
@Test(expected = ProcessException.class)
public void testGetConnectionSqlException() throws SQLException {
SQLException sqlException = new SQLException("bad sql");
when(basicDataSource.getConnection()).thenThrow(sqlException);
try {
hiveConnectionPool.getConnection();
} catch (ProcessException e) {
assertEquals(sqlException, e.getCause());
throw e;
}
}
}

View File

@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.util.hive;
import com.google.common.util.concurrent.UncheckedExecutionException;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hive.hcatalog.streaming.HiveEndPoint;
import org.apache.hive.hcatalog.streaming.InvalidTable;
import org.apache.hive.hcatalog.streaming.RecordWriter;
import org.apache.hive.hcatalog.streaming.StreamingConnection;
import org.apache.hive.hcatalog.streaming.StreamingException;
import org.apache.hive.hcatalog.streaming.TransactionBatch;
import org.junit.Before;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class HiveWriterTest {
private HiveEndPoint hiveEndPoint;
private int txnsPerBatch;
private boolean autoCreatePartitions;
private int callTimeout;
private ExecutorService executorService;
private UserGroupInformation userGroupInformation;
private HiveConf hiveConf;
private HiveWriter hiveWriter;
private StreamingConnection streamingConnection;
private RecordWriter recordWriter;
private Callable<RecordWriter> recordWriterCallable;
private TransactionBatch transactionBatch;
@Before
public void setup() throws Exception {
hiveEndPoint = mock(HiveEndPoint.class);
txnsPerBatch = 100;
autoCreatePartitions = true;
callTimeout = 0;
executorService = mock(ExecutorService.class);
streamingConnection = mock(StreamingConnection.class);
transactionBatch = mock(TransactionBatch.class);
userGroupInformation = mock(UserGroupInformation.class);
hiveConf = mock(HiveConf.class);
recordWriter = mock(RecordWriter.class);
recordWriterCallable = mock(Callable.class);
when(recordWriterCallable.call()).thenReturn(recordWriter);
when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenReturn(streamingConnection);
when(streamingConnection.fetchTransactionBatch(txnsPerBatch, recordWriter)).thenReturn(transactionBatch);
when(executorService.submit(isA(Callable.class))).thenAnswer(invocation -> {
Future future = mock(Future.class);
Answer<Object> answer = i -> ((Callable) invocation.getArguments()[0]).call();
when(future.get()).thenAnswer(answer);
when(future.get(anyLong(), any(TimeUnit.class))).thenAnswer(answer);
return future;
});
when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> {
try {
try {
return ((PrivilegedExceptionAction) invocation.getArguments()[0]).run();
} catch (UncheckedExecutionException e) {
// Creation of strict json writer will fail due to external deps, this gives us chance to catch it
for (StackTraceElement stackTraceElement : e.getStackTrace()) {
if (stackTraceElement.toString().startsWith("org.apache.hive.hcatalog.streaming.StrictJsonWriter.<init>(")) {
return recordWriterCallable.call();
}
}
throw e;
}
} catch (IOException | Error | RuntimeException | InterruptedException e) {
throw e;
} catch (Throwable e) {
throw new UndeclaredThrowableException(e);
}
});
initWriter();
}
private void initWriter() throws Exception {
hiveWriter = new HiveWriter(hiveEndPoint, txnsPerBatch, autoCreatePartitions, callTimeout, executorService, userGroupInformation, hiveConf);
}
@Test
public void testNormal() {
assertNotNull(hiveWriter);
}
@Test(expected = HiveWriter.ConnectFailure.class)
public void testNewConnectionInvalidTable() throws Exception {
hiveEndPoint = mock(HiveEndPoint.class);
InvalidTable invalidTable = new InvalidTable("badDb", "badTable");
when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenThrow(invalidTable);
try {
initWriter();
} catch (HiveWriter.ConnectFailure e) {
assertEquals(invalidTable, e.getCause());
throw e;
}
}
@Test(expected = HiveWriter.ConnectFailure.class)
public void testRecordWriterStreamingException() throws Exception {
recordWriterCallable = mock(Callable.class);
StreamingException streamingException = new StreamingException("Test Exception");
when(recordWriterCallable.call()).thenThrow(streamingException);
try {
initWriter();
} catch (HiveWriter.ConnectFailure e) {
assertEquals(streamingException, e.getCause());
throw e;
}
}
}