HBASE-12953 RegionServer is not functionally working with AysncRpcClient in secure mode

Signed-off-by: stack <stack@apache.org>
This commit is contained in:
zhangduo 2015-02-18 09:46:27 +08:00 committed by stack
parent e405017a31
commit b20675f5af
5 changed files with 149 additions and 55 deletions

View File

@ -189,10 +189,11 @@ public class AsyncRpcChannel {
if (ticket == null) {
throw new FatalConnectionException("ticket/user is null");
}
final UserGroupInformation realTicket = ticket;
saslHandler = ticket.doAs(new PrivilegedExceptionAction<SaslClientHandler>() {
@Override
public SaslClientHandler run() throws IOException {
return getSaslHandler(bootstrap);
return getSaslHandler(realTicket, bootstrap);
}
});
if (saslHandler != null) {
@ -244,20 +245,21 @@ public class AsyncRpcChannel {
/**
* Get SASL handler
*
* @param bootstrap to reconnect to
* @return new SASL handler
* @throws java.io.IOException if handler failed to create
*/
private SaslClientHandler getSaslHandler(final Bootstrap bootstrap) throws IOException {
return new SaslClientHandler(authMethod, token, serverPrincipal, client.fallbackAllowed,
client.conf.get("hbase.rpc.protection", SaslUtil.QualityOfProtection.AUTHENTICATION.name()
.toLowerCase()), new SaslClientHandler.SaslExceptionHandler() {
private SaslClientHandler getSaslHandler(final UserGroupInformation realTicket,
final Bootstrap bootstrap) throws IOException {
return new SaslClientHandler(realTicket, authMethod, token, serverPrincipal,
client.fallbackAllowed, client.conf.get("hbase.rpc.protection",
SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()),
new SaslClientHandler.SaslExceptionHandler() {
@Override
public void handle(int retryCount, Random random, Throwable cause) {
try {
// Handle Sasl failure. Try to potentially get new credentials
handleSaslConnectionFailure(retryCount, cause, ticket.getUGI());
handleSaslConnectionFailure(retryCount, cause, realTicket);
// Try to reconnect
AsyncRpcClient.WHEEL_TIMER.newTimeout(new TimerTask() {

View File

@ -24,10 +24,12 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
@ -35,8 +37,10 @@ import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.PrivilegedExceptionAction;
import java.util.Random;
/**
@ -48,6 +52,8 @@ public class SaslClientHandler extends ChannelDuplexHandler {
private final boolean fallbackAllowed;
private final UserGroupInformation ticket;
/**
* Used for client or server's token to send or receive from each other.
*/
@ -63,6 +69,7 @@ public class SaslClientHandler extends ChannelDuplexHandler {
/**
* Constructor
*
* @param ticket the ugi
* @param method auth method
* @param token for Sasl
* @param serverPrincipal Server's Kerberos principal name
@ -72,10 +79,11 @@ public class SaslClientHandler extends ChannelDuplexHandler {
* @param successfulConnectHandler handler for succesful connects
* @throws java.io.IOException if handler could not be created
*/
public SaslClientHandler(AuthMethod method, Token<? extends TokenIdentifier> token,
String serverPrincipal, boolean fallbackAllowed, String rpcProtection,
SaslExceptionHandler exceptionHandler, SaslSuccessfulConnectHandler successfulConnectHandler)
throws IOException {
public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
String rpcProtection, SaslExceptionHandler exceptionHandler,
SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
this.ticket = ticket;
this.fallbackAllowed = fallbackAllowed;
this.exceptionHandler = exceptionHandler;
@ -109,9 +117,10 @@ public class SaslClientHandler extends ChannelDuplexHandler {
default:
throw new IOException("Unknown authentication method " + method);
}
if (saslClient == null)
if (saslClient == null) {
throw new IOException("Unable to find SASL client implementation");
}
}
/**
* Create a Digest Sasl client
@ -144,14 +153,26 @@ public class SaslClientHandler extends ChannelDuplexHandler {
null);
}
@Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
saslClient.dispose();
}
@Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.saslToken = new byte[0];
private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
@Override
public byte[] run() throws Exception {
return saslClient.evaluateChallenge(challenge);
}
});
}
@Override
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
saslToken = new byte[0];
if (saslClient.hasInitialResponse()) {
saslToken = saslClient.evaluateChallenge(saslToken);
saslToken = evaluateChallenge(saslToken);
}
if (saslToken != null) {
writeSaslToken(ctx, saslToken);
@ -161,7 +182,8 @@ public class SaslClientHandler extends ChannelDuplexHandler {
}
}
@Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf in = (ByteBuf) msg;
// If not complete, try to negotiate
@ -187,15 +209,17 @@ public class SaslClientHandler extends ChannelDuplexHandler {
}
}
saslToken = new byte[len];
if (LOG.isDebugEnabled())
if (LOG.isDebugEnabled()) {
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
}
in.readBytes(saslToken);
saslToken = saslClient.evaluateChallenge(saslToken);
saslToken = evaluateChallenge(saslToken);
if (saslToken != null) {
if (LOG.isDebugEnabled())
if (LOG.isDebugEnabled()) {
LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
}
writeSaslToken(ctx, saslToken);
}
}
@ -246,7 +270,6 @@ public class SaslClientHandler extends ChannelDuplexHandler {
/**
* Write SASL token
*
* @param ctx to write to
* @param saslToken to write
*/
@ -255,7 +278,8 @@ public class SaslClientHandler extends ChannelDuplexHandler {
b.writeInt(saslToken.length);
b.writeBytes(saslToken, 0, saslToken.length);
ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
@Override public void operationComplete(ChannelFuture future) throws Exception {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
exceptionCaught(ctx, future.cause());
}
@ -289,7 +313,8 @@ public class SaslClientHandler extends ChannelDuplexHandler {
exceptionHandler.handle(this.retryCount++, this.random, cause);
}
@Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
// If not complete, try to negotiate
if (!saslClient.isComplete()) {

View File

@ -479,6 +479,11 @@
<artifactId>hamcrest-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-minikdc</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<!-- Skip the tests in this module -->

View File

@ -21,32 +21,39 @@ package org.apache.hadoop.hbase.security;
import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getKeytabFileForTesting;
import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getPrincipalForTesting;
import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getSecuredConfiguration;
import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.isKerberosPropertySetted;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assume.assumeTrue;
import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.hbase.HBaseTestingUtility;
import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.ServerName;
import org.apache.hadoop.hbase.ipc.RpcClientFactory;
import org.apache.hadoop.hbase.testclassification.SecurityTests;
import org.apache.hadoop.hbase.testclassification.SmallTests;
import org.apache.hadoop.hbase.ipc.AsyncRpcClient;
import org.apache.hadoop.hbase.ipc.FifoRpcScheduler;
import org.apache.hadoop.hbase.ipc.RpcClient;
import org.apache.hadoop.hbase.ipc.RpcClientFactory;
import org.apache.hadoop.hbase.ipc.RpcClientImpl;
import org.apache.hadoop.hbase.ipc.RpcServer;
import org.apache.hadoop.hbase.ipc.RpcServerInterface;
import org.apache.hadoop.hbase.ipc.TestDelayedRpc.TestDelayedImplementation;
import org.apache.hadoop.hbase.ipc.TestDelayedRpc.TestThread;
import org.apache.hadoop.hbase.ipc.protobuf.generated.TestDelayedRpcProtos;
import org.apache.hadoop.hbase.testclassification.SecurityTests;
import org.apache.hadoop.hbase.testclassification.SmallTests;
import org.apache.hadoop.minikdc.MiniKdc;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.mockito.Mockito;
@ -57,17 +64,51 @@ import com.google.protobuf.BlockingService;
@Category({ SecurityTests.class, SmallTests.class })
public class TestSecureRPC {
public static RpcServerInterface rpcServer;
/**
* To run this test, we must specify the following system properties:
*<p>
* <b> hbase.regionserver.kerberos.principal </b>
* <p>
* <b> hbase.regionserver.keytab.file </b>
*/
private static final HBaseTestingUtility TEST_UTIL = new HBaseTestingUtility();
private static final File KEYTAB_FILE = new File(TEST_UTIL.getDataTestDir("keytab").toUri()
.getPath());
private static MiniKdc KDC;
private static String HOST;
private static String PRINCIPAL;
@BeforeClass
public static void setUp() throws Exception {
Properties conf = MiniKdc.createConf();
conf.put(MiniKdc.DEBUG, true);
KDC = new MiniKdc(conf, new File(TEST_UTIL.getDataTestDir("kdc").toUri().getPath()));
KDC.start();
HOST = InetAddress.getLocalHost().getHostName();
PRINCIPAL = "hbase/" + HOST;
KDC.createPrincipal(KEYTAB_FILE, PRINCIPAL);
HBaseKerberosUtils.setKeytabFileForTesting(KEYTAB_FILE.getAbsolutePath());
HBaseKerberosUtils.setPrincipalForTesting(PRINCIPAL + "@" + KDC.getRealm());
}
@AfterClass
public static void tearDown() throws IOException {
if (KDC != null) {
KDC.stop();
}
TEST_UTIL.cleanupTestDir();
}
@Test
public void testRpcCallWithEnabledKerberosSaslAuth() throws Exception {
assumeTrue(isKerberosPropertySetted());
public void testRpc() throws Exception {
testRpcCallWithEnabledKerberosSaslAuth(RpcClientImpl.class);
}
@Test
public void testAsyncRpc() throws Exception {
testRpcCallWithEnabledKerberosSaslAuth(AsyncRpcClient.class);
}
private void testRpcCallWithEnabledKerberosSaslAuth(Class<? extends RpcClient> rpcImplClass)
throws Exception {
String krbKeytab = getKeytabFileForTesting();
String krbPrincipal = getPrincipalForTesting();
@ -84,40 +125,42 @@ public class TestSecureRPC {
assertEquals(krbPrincipal, ugi.getUserName());
Configuration conf = getSecuredConfiguration();
conf.set(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, rpcImplClass.getName());
SecurityInfo securityInfoMock = Mockito.mock(SecurityInfo.class);
Mockito.when(securityInfoMock.getServerPrincipal())
.thenReturn(HBaseKerberosUtils.KRB_PRINCIPAL);
SecurityInfo.addInfo("TestDelayedService", securityInfoMock);
boolean delayReturnValue = false;
InetSocketAddress isa = new InetSocketAddress("localhost", 0);
InetSocketAddress isa = new InetSocketAddress(HOST, 0);
TestDelayedImplementation instance = new TestDelayedImplementation(delayReturnValue);
BlockingService service =
TestDelayedRpcProtos.TestDelayedService.newReflectiveBlockingService(instance);
rpcServer = new RpcServer(null, "testSecuredDelayedRpc",
Lists.newArrayList(new RpcServer.BlockingServiceAndInterface(service, null)),
isa, conf, new FifoRpcScheduler(conf, 1));
RpcServerInterface rpcServer =
new RpcServer(null, "testSecuredDelayedRpc",
Lists.newArrayList(new RpcServer.BlockingServiceAndInterface(service, null)), isa,
conf, new FifoRpcScheduler(conf, 1));
rpcServer.start();
RpcClient rpcClient = RpcClientFactory
.createClient(conf, HConstants.DEFAULT_CLUSTER_ID.toString());
RpcClient rpcClient =
RpcClientFactory.createClient(conf, HConstants.DEFAULT_CLUSTER_ID.toString());
try {
BlockingRpcChannel channel = rpcClient.createBlockingRpcChannel(
ServerName.valueOf(rpcServer.getListenerAddress().getHostName(),
rpcServer.getListenerAddress().getPort(), System.currentTimeMillis()),
User.getCurrent(), 1000);
BlockingRpcChannel channel =
rpcClient.createBlockingRpcChannel(
ServerName.valueOf(rpcServer.getListenerAddress().getHostName(), rpcServer
.getListenerAddress().getPort(), System.currentTimeMillis()), User.getCurrent(),
5000);
TestDelayedRpcProtos.TestDelayedService.BlockingInterface stub =
TestDelayedRpcProtos.TestDelayedService.newBlockingStub(channel);
List<Integer> results = new ArrayList<Integer>();
TestThread th1 = new TestThread(stub, true, results);
th1.start();
Thread.sleep(100);
th1.join();
assertEquals(0xDEADBEEF, results.get(0).intValue());
} finally {
rpcClient.close();
rpcServer.stop();
}
}
}

19
pom.xml
View File

@ -998,6 +998,13 @@
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.felix</groupId>
<artifactId>maven-bundle-plugin</artifactId>
<version>2.5.3</version>
<inherited>true</inherited>
<extensions>true</extensions>
</plugin>
</plugins>
</build>
<properties>
@ -1842,6 +1849,12 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-minikdc</artifactId>
<version>${hadoop-two.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>
</profile>
@ -2007,6 +2020,12 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-minikdc</artifactId>
<version>${hadoop-three.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>