diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java index b3e01f3e093..84142901283 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java @@ -189,10 +189,11 @@ public class AsyncRpcChannel { if (ticket == null) { throw new FatalConnectionException("ticket/user is null"); } + final UserGroupInformation realTicket = ticket; saslHandler = ticket.doAs(new PrivilegedExceptionAction() { @Override public SaslClientHandler run() throws IOException { - return getSaslHandler(bootstrap); + return getSaslHandler(realTicket, bootstrap); } }); if (saslHandler != null) { @@ -244,20 +245,21 @@ public class AsyncRpcChannel { /** * Get SASL handler - * * @param bootstrap to reconnect to * @return new SASL handler * @throws java.io.IOException if handler failed to create */ - private SaslClientHandler getSaslHandler(final Bootstrap bootstrap) throws IOException { - return new SaslClientHandler(authMethod, token, serverPrincipal, client.fallbackAllowed, - client.conf.get("hbase.rpc.protection", SaslUtil.QualityOfProtection.AUTHENTICATION.name() - .toLowerCase()), new SaslClientHandler.SaslExceptionHandler() { + private SaslClientHandler getSaslHandler(final UserGroupInformation realTicket, + final Bootstrap bootstrap) throws IOException { + return new SaslClientHandler(realTicket, authMethod, token, serverPrincipal, + client.fallbackAllowed, client.conf.get("hbase.rpc.protection", + SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()), + new SaslClientHandler.SaslExceptionHandler() { @Override public void handle(int retryCount, Random random, Throwable cause) { try { // Handle Sasl failure. Try to potentially get new credentials - handleSaslConnectionFailure(retryCount, cause, ticket.getUGI()); + handleSaslConnectionFailure(retryCount, cause, realTicket); // Try to reconnect AsyncRpcClient.WHEEL_TIMER.newTimeout(new TimerTask() { diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java index 50445c144c7..1be59bc7a62 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java @@ -24,10 +24,12 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.ipc.RemoteException; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; @@ -35,8 +37,10 @@ import javax.security.auth.callback.CallbackHandler; import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; + import java.io.IOException; import java.nio.charset.Charset; +import java.security.PrivilegedExceptionAction; import java.util.Random; /** @@ -48,6 +52,8 @@ public class SaslClientHandler extends ChannelDuplexHandler { private final boolean fallbackAllowed; + private final UserGroupInformation ticket; + /** * Used for client or server's token to send or receive from each other. */ @@ -63,6 +69,7 @@ public class SaslClientHandler extends ChannelDuplexHandler { /** * Constructor * + * @param ticket the ugi * @param method auth method * @param token for Sasl * @param serverPrincipal Server's Kerberos principal name @@ -72,10 +79,11 @@ public class SaslClientHandler extends ChannelDuplexHandler { * @param successfulConnectHandler handler for succesful connects * @throws java.io.IOException if handler could not be created */ - public SaslClientHandler(AuthMethod method, Token token, - String serverPrincipal, boolean fallbackAllowed, String rpcProtection, - SaslExceptionHandler exceptionHandler, SaslSuccessfulConnectHandler successfulConnectHandler) - throws IOException { + public SaslClientHandler(UserGroupInformation ticket, AuthMethod method, + Token token, String serverPrincipal, boolean fallbackAllowed, + String rpcProtection, SaslExceptionHandler exceptionHandler, + SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException { + this.ticket = ticket; this.fallbackAllowed = fallbackAllowed; this.exceptionHandler = exceptionHandler; @@ -109,8 +117,9 @@ public class SaslClientHandler extends ChannelDuplexHandler { default: throw new IOException("Unknown authentication method " + method); } - if (saslClient == null) + if (saslClient == null) { throw new IOException("Unable to find SASL client implementation"); + } } /** @@ -144,14 +153,26 @@ public class SaslClientHandler extends ChannelDuplexHandler { null); } - @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { saslClient.dispose(); } - @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { - this.saslToken = new byte[0]; + private byte[] evaluateChallenge(final byte[] challenge) throws Exception { + return ticket.doAs(new PrivilegedExceptionAction() { + + @Override + public byte[] run() throws Exception { + return saslClient.evaluateChallenge(challenge); + } + }); + } + + @Override + public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { + saslToken = new byte[0]; if (saslClient.hasInitialResponse()) { - saslToken = saslClient.evaluateChallenge(saslToken); + saslToken = evaluateChallenge(saslToken); } if (saslToken != null) { writeSaslToken(ctx, saslToken); @@ -161,7 +182,8 @@ public class SaslClientHandler extends ChannelDuplexHandler { } } - @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { ByteBuf in = (ByteBuf) msg; // If not complete, try to negotiate @@ -187,15 +209,17 @@ public class SaslClientHandler extends ChannelDuplexHandler { } } saslToken = new byte[len]; - if (LOG.isDebugEnabled()) + if (LOG.isDebugEnabled()) { LOG.debug("Will read input token of size " + saslToken.length + " for processing by initSASLContext"); + } in.readBytes(saslToken); - saslToken = saslClient.evaluateChallenge(saslToken); + saslToken = evaluateChallenge(saslToken); if (saslToken != null) { - if (LOG.isDebugEnabled()) + if (LOG.isDebugEnabled()) { LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext."); + } writeSaslToken(ctx, saslToken); } } @@ -246,8 +270,7 @@ public class SaslClientHandler extends ChannelDuplexHandler { /** * Write SASL token - * - * @param ctx to write to + * @param ctx to write to * @param saslToken to write */ private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) { @@ -255,7 +278,8 @@ public class SaslClientHandler extends ChannelDuplexHandler { b.writeInt(saslToken.length); b.writeBytes(saslToken, 0, saslToken.length); ctx.writeAndFlush(b).addListener(new ChannelFutureListener() { - @Override public void operationComplete(ChannelFuture future) throws Exception { + @Override + public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { exceptionCaught(ctx, future.cause()); } @@ -289,7 +313,8 @@ public class SaslClientHandler extends ChannelDuplexHandler { exceptionHandler.handle(this.retryCount++, this.random, cause); } - @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { // If not complete, try to negotiate if (!saslClient.isComplete()) { diff --git a/hbase-server/pom.xml b/hbase-server/pom.xml index c9ba4dadfeb..6df52936bf6 100644 --- a/hbase-server/pom.xml +++ b/hbase-server/pom.xml @@ -479,6 +479,11 @@ hamcrest-core test + + org.apache.hadoop + hadoop-minikdc + test + diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java index b28a1ef8fce..8ac38fa93df 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java @@ -21,32 +21,39 @@ package org.apache.hadoop.hbase.security; import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getKeytabFileForTesting; import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getPrincipalForTesting; import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.getSecuredConfiguration; -import static org.apache.hadoop.hbase.security.HBaseKerberosUtils.isKerberosPropertySetted; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; -import static org.junit.Assume.assumeTrue; +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; +import java.util.Properties; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; +import org.apache.hadoop.hbase.HBaseTestingUtility; import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.ServerName; -import org.apache.hadoop.hbase.ipc.RpcClientFactory; -import org.apache.hadoop.hbase.testclassification.SecurityTests; -import org.apache.hadoop.hbase.testclassification.SmallTests; +import org.apache.hadoop.hbase.ipc.AsyncRpcClient; import org.apache.hadoop.hbase.ipc.FifoRpcScheduler; import org.apache.hadoop.hbase.ipc.RpcClient; +import org.apache.hadoop.hbase.ipc.RpcClientFactory; +import org.apache.hadoop.hbase.ipc.RpcClientImpl; import org.apache.hadoop.hbase.ipc.RpcServer; import org.apache.hadoop.hbase.ipc.RpcServerInterface; import org.apache.hadoop.hbase.ipc.TestDelayedRpc.TestDelayedImplementation; import org.apache.hadoop.hbase.ipc.TestDelayedRpc.TestThread; import org.apache.hadoop.hbase.ipc.protobuf.generated.TestDelayedRpcProtos; +import org.apache.hadoop.hbase.testclassification.SecurityTests; +import org.apache.hadoop.hbase.testclassification.SmallTests; +import org.apache.hadoop.minikdc.MiniKdc; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.experimental.categories.Category; import org.mockito.Mockito; @@ -55,19 +62,53 @@ import com.google.common.collect.Lists; import com.google.protobuf.BlockingRpcChannel; import com.google.protobuf.BlockingService; -@Category({SecurityTests.class, SmallTests.class}) +@Category({ SecurityTests.class, SmallTests.class }) public class TestSecureRPC { - public static RpcServerInterface rpcServer; - /** - * To run this test, we must specify the following system properties: - *

- * hbase.regionserver.kerberos.principal - *

- * hbase.regionserver.keytab.file - */ + + private static final HBaseTestingUtility TEST_UTIL = new HBaseTestingUtility(); + + private static final File KEYTAB_FILE = new File(TEST_UTIL.getDataTestDir("keytab").toUri() + .getPath()); + + private static MiniKdc KDC; + + private static String HOST; + + private static String PRINCIPAL; + + @BeforeClass + public static void setUp() throws Exception { + Properties conf = MiniKdc.createConf(); + conf.put(MiniKdc.DEBUG, true); + KDC = new MiniKdc(conf, new File(TEST_UTIL.getDataTestDir("kdc").toUri().getPath())); + KDC.start(); + HOST = InetAddress.getLocalHost().getHostName(); + PRINCIPAL = "hbase/" + HOST; + KDC.createPrincipal(KEYTAB_FILE, PRINCIPAL); + HBaseKerberosUtils.setKeytabFileForTesting(KEYTAB_FILE.getAbsolutePath()); + HBaseKerberosUtils.setPrincipalForTesting(PRINCIPAL + "@" + KDC.getRealm()); + } + + @AfterClass + public static void tearDown() throws IOException { + if (KDC != null) { + KDC.stop(); + } + TEST_UTIL.cleanupTestDir(); + } + @Test - public void testRpcCallWithEnabledKerberosSaslAuth() throws Exception { - assumeTrue(isKerberosPropertySetted()); + public void testRpc() throws Exception { + testRpcCallWithEnabledKerberosSaslAuth(RpcClientImpl.class); + } + + @Test + public void testAsyncRpc() throws Exception { + testRpcCallWithEnabledKerberosSaslAuth(AsyncRpcClient.class); + } + + private void testRpcCallWithEnabledKerberosSaslAuth(Class rpcImplClass) + throws Exception { String krbKeytab = getKeytabFileForTesting(); String krbPrincipal = getPrincipalForTesting(); @@ -84,40 +125,42 @@ public class TestSecureRPC { assertEquals(krbPrincipal, ugi.getUserName()); Configuration conf = getSecuredConfiguration(); - + conf.set(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, rpcImplClass.getName()); SecurityInfo securityInfoMock = Mockito.mock(SecurityInfo.class); Mockito.when(securityInfoMock.getServerPrincipal()) - .thenReturn(HBaseKerberosUtils.KRB_PRINCIPAL); + .thenReturn(HBaseKerberosUtils.KRB_PRINCIPAL); SecurityInfo.addInfo("TestDelayedService", securityInfoMock); boolean delayReturnValue = false; - InetSocketAddress isa = new InetSocketAddress("localhost", 0); + InetSocketAddress isa = new InetSocketAddress(HOST, 0); TestDelayedImplementation instance = new TestDelayedImplementation(delayReturnValue); BlockingService service = TestDelayedRpcProtos.TestDelayedService.newReflectiveBlockingService(instance); - rpcServer = new RpcServer(null, "testSecuredDelayedRpc", - Lists.newArrayList(new RpcServer.BlockingServiceAndInterface(service, null)), - isa, conf, new FifoRpcScheduler(conf, 1)); + RpcServerInterface rpcServer = + new RpcServer(null, "testSecuredDelayedRpc", + Lists.newArrayList(new RpcServer.BlockingServiceAndInterface(service, null)), isa, + conf, new FifoRpcScheduler(conf, 1)); rpcServer.start(); - RpcClient rpcClient = RpcClientFactory - .createClient(conf, HConstants.DEFAULT_CLUSTER_ID.toString()); + RpcClient rpcClient = + RpcClientFactory.createClient(conf, HConstants.DEFAULT_CLUSTER_ID.toString()); try { - BlockingRpcChannel channel = rpcClient.createBlockingRpcChannel( - ServerName.valueOf(rpcServer.getListenerAddress().getHostName(), - rpcServer.getListenerAddress().getPort(), System.currentTimeMillis()), - User.getCurrent(), 1000); + BlockingRpcChannel channel = + rpcClient.createBlockingRpcChannel( + ServerName.valueOf(rpcServer.getListenerAddress().getHostName(), rpcServer + .getListenerAddress().getPort(), System.currentTimeMillis()), User.getCurrent(), + 5000); TestDelayedRpcProtos.TestDelayedService.BlockingInterface stub = - TestDelayedRpcProtos.TestDelayedService.newBlockingStub(channel); + TestDelayedRpcProtos.TestDelayedService.newBlockingStub(channel); List results = new ArrayList(); TestThread th1 = new TestThread(stub, true, results); th1.start(); - Thread.sleep(100); th1.join(); assertEquals(0xDEADBEEF, results.get(0).intValue()); } finally { rpcClient.close(); + rpcServer.stop(); } } } \ No newline at end of file diff --git a/pom.xml b/pom.xml index 399850e0c5a..b0b26814d31 100644 --- a/pom.xml +++ b/pom.xml @@ -998,6 +998,13 @@ + + org.apache.felix + maven-bundle-plugin + 2.5.3 + true + true + @@ -1842,6 +1849,12 @@ + + org.apache.hadoop + hadoop-minikdc + ${hadoop-two.version} + test + @@ -2007,6 +2020,12 @@ + + org.apache.hadoop + hadoop-minikdc + ${hadoop-three.version} + test +