diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java index 1daf8039136..c9ac61562d7 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java @@ -633,7 +633,8 @@ public class Client implements AutoCloseable { return false; } - private synchronized void setupConnection() throws IOException { + private synchronized void setupConnection( + UserGroupInformation ticket) throws IOException { short ioFailures = 0; short timeoutFailures = 0; while (true) { @@ -661,24 +662,26 @@ public class Client implements AutoCloseable { * client, to ensure Server matching address of the client connection * to host name in principal passed. */ - UserGroupInformation ticket = remoteId.getTicket(); + InetSocketAddress bindAddr = null; if (ticket != null && ticket.hasKerberosCredentials()) { KerberosInfo krbInfo = remoteId.getProtocol().getAnnotation(KerberosInfo.class); - if (krbInfo != null && krbInfo.clientPrincipal() != null) { - String host = - SecurityUtil.getHostFromPrincipal(remoteId.getTicket().getUserName()); - + if (krbInfo != null) { + String principal = ticket.getUserName(); + String host = SecurityUtil.getHostFromPrincipal(principal); // If host name is a valid local address then bind socket to it InetAddress localAddr = NetUtils.getLocalInetAddress(host); if (localAddr != null) { this.socket.setReuseAddress(true); - this.socket.bind(new InetSocketAddress(localAddr, 0)); + if (LOG.isDebugEnabled()) { + LOG.debug("Binding " + principal + " to " + localAddr); + } + bindAddr = new InetSocketAddress(localAddr, 0); } } } - NetUtils.connect(this.socket, server, connectionTimeout); + NetUtils.connect(this.socket, server, bindAddr, connectionTimeout); this.socket.setSoTimeout(soTimeout); return; } catch (ConnectTimeoutException toe) { @@ -762,7 +765,14 @@ public class Client implements AutoCloseable { AtomicBoolean fallbackToSimpleAuth) { if (socket != null || shouldCloseConnection.get()) { return; - } + } + UserGroupInformation ticket = remoteId.getTicket(); + if (ticket != null) { + final UserGroupInformation realUser = ticket.getRealUser(); + if (realUser != null) { + ticket = realUser; + } + } try { if (LOG.isDebugEnabled()) { LOG.debug("Connecting to "+server); @@ -774,14 +784,10 @@ public class Client implements AutoCloseable { short numRetries = 0; Random rand = null; while (true) { - setupConnection(); + setupConnection(ticket); ipcStreams = new IpcStreams(socket, maxResponseLength); writeConnectionHeader(ipcStreams); if (authProtocol == AuthProtocol.SASL) { - UserGroupInformation ticket = remoteId.getTicket(); - if (ticket.getRealUser() != null) { - ticket = ticket.getRealUser(); - } try { authMethod = ticket .doAs(new PrivilegedExceptionAction() { diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java index 3416746ab32..a4577f2923c 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java @@ -39,9 +39,12 @@ import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Method; import java.lang.reflect.Proxy; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketException; import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.Collections; @@ -76,6 +79,7 @@ import org.apache.hadoop.ipc.Server.Connection; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto; import org.apache.hadoop.net.ConnectTimeoutException; import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.KerberosInfo; import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; @@ -1484,6 +1488,78 @@ public class TestIPC { Assert.fail("didn't get limit exceeded"); } + @Test + public void testUserBinding() throws Exception { + checkUserBinding(false); + } + + @Test + public void testProxyUserBinding() throws Exception { + checkUserBinding(true); + } + + private void checkUserBinding(boolean asProxy) throws Exception { + Socket s; + // don't attempt bind with no service host. + s = checkConnect(null, asProxy); + Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class)); + + // don't attempt bind with service host not belonging to this host. + s = checkConnect("1.2.3.4", asProxy); + Mockito.verify(s, Mockito.never()).bind(Mockito.any(SocketAddress.class)); + + // do attempt bind when service host is this host. + InetAddress addr = InetAddress.getLocalHost(); + s = checkConnect(addr.getHostAddress(), asProxy); + Mockito.verify(s).bind(new InetSocketAddress(addr, 0)); + } + + // dummy protocol that claims to support kerberos. + @KerberosInfo(serverPrincipal = "server@REALM") + private static class TestBindingProtocol { + } + + private Socket checkConnect(String addr, boolean asProxy) throws Exception { + // create a fake ugi that claims to have kerberos credentials. + StringBuilder principal = new StringBuilder(); + principal.append("client"); + if (addr != null) { + principal.append("/").append(addr); + } + principal.append("@REALM"); + UserGroupInformation ugi = + spy(UserGroupInformation.createRemoteUser(principal.toString())); + Mockito.doReturn(true).when(ugi).hasKerberosCredentials(); + if (asProxy) { + ugi = UserGroupInformation.createProxyUser("proxy", ugi); + } + + // create a mock socket that throws on connect. + SocketException expectedConnectEx = + new SocketException("Expected connect failure"); + Socket s = Mockito.mock(Socket.class); + SocketFactory mockFactory = Mockito.mock(SocketFactory.class); + Mockito.doReturn(s).when(mockFactory).createSocket(); + doThrow(expectedConnectEx).when(s).connect( + Mockito.any(SocketAddress.class), Mockito.anyInt()); + + // do a dummy call and expect it to throw an exception on connect. + // tests should verify if/how a bind occurred. + try (Client client = new Client(LongWritable.class, conf, mockFactory)) { + final InetSocketAddress sockAddr = new InetSocketAddress(0); + final LongWritable param = new LongWritable(RANDOM.nextLong()); + final ConnectionId remoteId = new ConnectionId( + sockAddr, TestBindingProtocol.class, ugi, 0, + RetryPolicies.TRY_ONCE_THEN_FAIL, conf); + client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId, null); + fail("call didn't throw connect exception"); + } catch (SocketException se) { + // ipc layer re-wraps exceptions, so check the cause. + Assert.assertSame(expectedConnectEx, se.getCause()); + } + return s; + } + private void doIpcVersionTest( byte[] requestData, byte[] expectedResponse) throws IOException {