mirror of https://github.com/apache/nifi.git
NIFI-3755 - Restoring Hive exception handling behavior
This closes #1711. Signed-off-by: Bryan Bende <bbende@apache.org>
This commit is contained in:
parent
aa4efb43ca
commit
0054a9e35f
|
@ -44,6 +44,7 @@ import org.apache.nifi.util.hive.HiveUtils;
|
|||
import org.apache.nifi.util.hive.ValidationResources;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.UndeclaredThrowableException;
|
||||
import java.security.PrivilegedExceptionAction;
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
|
@ -278,13 +279,16 @@ public class HiveConnectionPool extends AbstractControllerService implements Hiv
|
|||
public Connection getConnection() throws ProcessException {
|
||||
try {
|
||||
if (ugi != null) {
|
||||
return ugi.doAs(new PrivilegedExceptionAction<Connection>() {
|
||||
@Override
|
||||
public Connection run() throws Exception {
|
||||
return dataSource.getConnection();
|
||||
try {
|
||||
return ugi.doAs((PrivilegedExceptionAction<Connection>) () -> dataSource.getConnection());
|
||||
} catch (UndeclaredThrowableException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof SQLException) {
|
||||
throw (SQLException) cause;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
} else {
|
||||
getLogger().info("Simple Authentication");
|
||||
return dataSource.getConnection();
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
package org.apache.nifi.util.hive;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.UndeclaredThrowableException;
|
||||
import java.security.PrivilegedExceptionAction;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
@ -83,7 +84,16 @@ public class HiveWriter {
|
|||
if (ugi == null) {
|
||||
return new StrictJsonWriter(endPoint, hiveConf);
|
||||
} else {
|
||||
return ugi.doAs((PrivilegedExceptionAction<StrictJsonWriter>) () -> new StrictJsonWriter(endPoint, hiveConf));
|
||||
try {
|
||||
return ugi.doAs((PrivilegedExceptionAction<StrictJsonWriter>) () -> new StrictJsonWriter(endPoint, hiveConf));
|
||||
} catch (UndeclaredThrowableException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof StreamingException) {
|
||||
throw (StreamingException) cause;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -354,7 +364,16 @@ public class HiveWriter {
|
|||
if (ugi == null) {
|
||||
return callRunner.call();
|
||||
}
|
||||
return ugi.doAs((PrivilegedExceptionAction<T>) () -> callRunner.call());
|
||||
try {
|
||||
return ugi.doAs((PrivilegedExceptionAction<T>) () -> callRunner.call());
|
||||
} catch (UndeclaredThrowableException e) {
|
||||
Throwable cause = e.getCause();
|
||||
// Unwrap exception so it is thrown the same way as without ugi
|
||||
if (!(cause instanceof Exception)) {
|
||||
throw e;
|
||||
}
|
||||
throw (Exception)cause;
|
||||
}
|
||||
});
|
||||
try {
|
||||
if (callTimeout > 0) {
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.nifi.dbcp.hive;
|
||||
|
||||
import org.apache.commons.dbcp.BasicDataSource;
|
||||
import org.apache.hadoop.security.UserGroupInformation;
|
||||
import org.apache.nifi.controller.AbstractControllerService;
|
||||
import org.apache.nifi.logging.ComponentLog;
|
||||
import org.apache.nifi.processor.exception.ProcessException;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.UndeclaredThrowableException;
|
||||
import java.security.PrivilegedExceptionAction;
|
||||
import java.sql.SQLException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.mockito.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class HiveConnectionPoolTest {
|
||||
private UserGroupInformation userGroupInformation;
|
||||
private HiveConnectionPool hiveConnectionPool;
|
||||
private BasicDataSource basicDataSource;
|
||||
private ComponentLog componentLog;
|
||||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
userGroupInformation = mock(UserGroupInformation.class);
|
||||
basicDataSource = mock(BasicDataSource.class);
|
||||
componentLog = mock(ComponentLog.class);
|
||||
|
||||
when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> {
|
||||
try {
|
||||
return ((PrivilegedExceptionAction) invocation.getArguments()[0]).run();
|
||||
} catch (IOException |Error|RuntimeException|InterruptedException e) {
|
||||
throw e;
|
||||
} catch (Throwable e) {
|
||||
throw new UndeclaredThrowableException(e);
|
||||
}
|
||||
});
|
||||
initPool();
|
||||
}
|
||||
|
||||
private void initPool() throws Exception {
|
||||
hiveConnectionPool = new HiveConnectionPool();
|
||||
|
||||
Field ugiField = HiveConnectionPool.class.getDeclaredField("ugi");
|
||||
ugiField.setAccessible(true);
|
||||
ugiField.set(hiveConnectionPool, userGroupInformation);
|
||||
|
||||
Field dataSourceField = HiveConnectionPool.class.getDeclaredField("dataSource");
|
||||
dataSourceField.setAccessible(true);
|
||||
dataSourceField.set(hiveConnectionPool, basicDataSource);
|
||||
|
||||
Field componentLogField = AbstractControllerService.class.getDeclaredField("logger");
|
||||
componentLogField.setAccessible(true);
|
||||
componentLogField.set(hiveConnectionPool, componentLog);
|
||||
}
|
||||
|
||||
@Test(expected = ProcessException.class)
|
||||
public void testGetConnectionSqlException() throws SQLException {
|
||||
SQLException sqlException = new SQLException("bad sql");
|
||||
when(basicDataSource.getConnection()).thenThrow(sqlException);
|
||||
try {
|
||||
hiveConnectionPool.getConnection();
|
||||
} catch (ProcessException e) {
|
||||
assertEquals(sqlException, e.getCause());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.nifi.util.hive;
|
||||
|
||||
import com.google.common.util.concurrent.UncheckedExecutionException;
|
||||
import org.apache.hadoop.hive.conf.HiveConf;
|
||||
import org.apache.hadoop.security.UserGroupInformation;
|
||||
import org.apache.hive.hcatalog.streaming.HiveEndPoint;
|
||||
import org.apache.hive.hcatalog.streaming.InvalidTable;
|
||||
import org.apache.hive.hcatalog.streaming.RecordWriter;
|
||||
import org.apache.hive.hcatalog.streaming.StreamingConnection;
|
||||
import org.apache.hive.hcatalog.streaming.StreamingException;
|
||||
import org.apache.hive.hcatalog.streaming.TransactionBatch;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.UndeclaredThrowableException;
|
||||
import java.security.PrivilegedExceptionAction;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyLong;
|
||||
import static org.mockito.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class HiveWriterTest {
|
||||
private HiveEndPoint hiveEndPoint;
|
||||
private int txnsPerBatch;
|
||||
private boolean autoCreatePartitions;
|
||||
private int callTimeout;
|
||||
private ExecutorService executorService;
|
||||
private UserGroupInformation userGroupInformation;
|
||||
private HiveConf hiveConf;
|
||||
private HiveWriter hiveWriter;
|
||||
private StreamingConnection streamingConnection;
|
||||
private RecordWriter recordWriter;
|
||||
private Callable<RecordWriter> recordWriterCallable;
|
||||
private TransactionBatch transactionBatch;
|
||||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
hiveEndPoint = mock(HiveEndPoint.class);
|
||||
txnsPerBatch = 100;
|
||||
autoCreatePartitions = true;
|
||||
callTimeout = 0;
|
||||
executorService = mock(ExecutorService.class);
|
||||
streamingConnection = mock(StreamingConnection.class);
|
||||
transactionBatch = mock(TransactionBatch.class);
|
||||
userGroupInformation = mock(UserGroupInformation.class);
|
||||
hiveConf = mock(HiveConf.class);
|
||||
recordWriter = mock(RecordWriter.class);
|
||||
recordWriterCallable = mock(Callable.class);
|
||||
when(recordWriterCallable.call()).thenReturn(recordWriter);
|
||||
|
||||
when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenReturn(streamingConnection);
|
||||
when(streamingConnection.fetchTransactionBatch(txnsPerBatch, recordWriter)).thenReturn(transactionBatch);
|
||||
when(executorService.submit(isA(Callable.class))).thenAnswer(invocation -> {
|
||||
Future future = mock(Future.class);
|
||||
Answer<Object> answer = i -> ((Callable) invocation.getArguments()[0]).call();
|
||||
when(future.get()).thenAnswer(answer);
|
||||
when(future.get(anyLong(), any(TimeUnit.class))).thenAnswer(answer);
|
||||
return future;
|
||||
});
|
||||
when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> {
|
||||
try {
|
||||
try {
|
||||
return ((PrivilegedExceptionAction) invocation.getArguments()[0]).run();
|
||||
} catch (UncheckedExecutionException e) {
|
||||
// Creation of strict json writer will fail due to external deps, this gives us chance to catch it
|
||||
for (StackTraceElement stackTraceElement : e.getStackTrace()) {
|
||||
if (stackTraceElement.toString().startsWith("org.apache.hive.hcatalog.streaming.StrictJsonWriter.<init>(")) {
|
||||
return recordWriterCallable.call();
|
||||
}
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
} catch (IOException | Error | RuntimeException | InterruptedException e) {
|
||||
throw e;
|
||||
} catch (Throwable e) {
|
||||
throw new UndeclaredThrowableException(e);
|
||||
}
|
||||
});
|
||||
|
||||
initWriter();
|
||||
}
|
||||
|
||||
private void initWriter() throws Exception {
|
||||
hiveWriter = new HiveWriter(hiveEndPoint, txnsPerBatch, autoCreatePartitions, callTimeout, executorService, userGroupInformation, hiveConf);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormal() {
|
||||
assertNotNull(hiveWriter);
|
||||
}
|
||||
|
||||
@Test(expected = HiveWriter.ConnectFailure.class)
|
||||
public void testNewConnectionInvalidTable() throws Exception {
|
||||
hiveEndPoint = mock(HiveEndPoint.class);
|
||||
InvalidTable invalidTable = new InvalidTable("badDb", "badTable");
|
||||
when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenThrow(invalidTable);
|
||||
try {
|
||||
initWriter();
|
||||
} catch (HiveWriter.ConnectFailure e) {
|
||||
assertEquals(invalidTable, e.getCause());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = HiveWriter.ConnectFailure.class)
|
||||
public void testRecordWriterStreamingException() throws Exception {
|
||||
recordWriterCallable = mock(Callable.class);
|
||||
StreamingException streamingException = new StreamingException("Test Exception");
|
||||
when(recordWriterCallable.call()).thenThrow(streamingException);
|
||||
try {
|
||||
initWriter();
|
||||
} catch (HiveWriter.ConnectFailure e) {
|
||||
assertEquals(streamingException, e.getCause());
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue