diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index 8844946ef63..5fbf64f1760 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -353,10 +353,9 @@ public abstract class Server { */ public static InetAddress getRemoteIp() { Call call = CurCall.get(); - return (call != null && call.connection != null) ? call.connection - .getHostInetAddress() : null; + return (call != null ) ? call.getHostInetAddress() : null; } - + /** * Returns the clientId from the current RPC request */ @@ -379,10 +378,9 @@ public abstract class Server { */ public static UserGroupInformation getRemoteUser() { Call call = CurCall.get(); - return (call != null && call.connection != null) ? call.connection.user - : null; + return (call != null) ? call.getRemoteUser() : null; } - + /** Return true if the invocation was through an RPC. */ public static boolean isRpcInvocation() { @@ -482,7 +480,7 @@ public abstract class Server { if ((rpcMetrics.getProcessingSampleCount() > minSampleSize) && (processingTime > threeSigma)) { if(LOG.isWarnEnabled()) { - String client = CurCall.get().connection.toString(); + String client = CurCall.get().toString(); LOG.warn( "Slow RPC : " + methodName + " took " + processingTime + " milliseconds to process from client " + client); @@ -656,62 +654,65 @@ public abstract class Server { CommonConfigurationKeys.IPC_BACKOFF_ENABLE_DEFAULT); } - /** A call queued for handling. */ - public static class Call implements Schedulable { - private final int callId; // the client's call id - private final int retryCount; // the retry count of the call - private final Writable rpcRequest; // Serialized Rpc request from client - private final Connection connection; // connection to client - private long timestamp; // time received when response is null - // time served when response is not null - private ByteBuffer rpcResponse; // the response for this call + /** A generic call queued for handling. */ + public static class Call implements Schedulable, + PrivilegedExceptionAction { + final int callId; // the client's call id + final int retryCount; // the retry count of the call + long timestamp; // time received when response is null + // time served when response is not null private AtomicInteger responseWaitCount = new AtomicInteger(1); - private final RPC.RpcKind rpcKind; - private final byte[] clientId; + final RPC.RpcKind rpcKind; + final byte[] clientId; private final TraceScope traceScope; // the HTrace scope on the server side private final CallerContext callerContext; // the call context private int priorityLevel; // the priority level assigned by scheduler, 0 by default - private Call(Call call) { - this(call.callId, call.retryCount, call.rpcRequest, call.connection, - call.rpcKind, call.clientId, call.traceScope, call.callerContext); + Call(Call call) { + this(call.callId, call.retryCount, call.rpcKind, call.clientId, + call.traceScope, call.callerContext); } - public Call(int id, int retryCount, Writable param, - Connection connection) { - this(id, retryCount, param, connection, RPC.RpcKind.RPC_BUILTIN, - RpcConstants.DUMMY_CLIENT_ID); + Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId) { + this(id, retryCount, kind, clientId, null, null); } - public Call(int id, int retryCount, Writable param, Connection connection, + @VisibleForTesting // primarily TestNamenodeRetryCache + public Call(int id, int retryCount, Void ignore1, Void ignore2, RPC.RpcKind kind, byte[] clientId) { - this(id, retryCount, param, connection, kind, clientId, null, null); + this(id, retryCount, kind, clientId, null, null); } - public Call(int id, int retryCount, Writable param, Connection connection, - RPC.RpcKind kind, byte[] clientId, TraceScope traceScope, - CallerContext callerContext) { + Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId, + TraceScope traceScope, CallerContext callerContext) { this.callId = id; this.retryCount = retryCount; - this.rpcRequest = param; - this.connection = connection; this.timestamp = Time.now(); - this.rpcResponse = null; this.rpcKind = kind; this.clientId = clientId; this.traceScope = traceScope; this.callerContext = callerContext; } - + @Override public String toString() { - return rpcRequest + " from " + connection + " Call#" + callId + " Retry#" - + retryCount; + return "Call#" + callId + " Retry#" + retryCount; } - public void setResponse(ByteBuffer response) { - this.rpcResponse = response; + public Void run() throws Exception { + return null; + } + // should eventually be abstract but need to avoid breaking tests + public UserGroupInformation getRemoteUser() { + return null; + } + public InetAddress getHostInetAddress() { + return null; + } + public String getHostAddress() { + InetAddress addr = getHostInetAddress(); + return (addr != null) ? addr.getHostAddress() : null; } /** @@ -723,34 +724,36 @@ public abstract class Server { */ @InterfaceStability.Unstable @InterfaceAudience.LimitedPrivate({"HDFS"}) - public void postponeResponse() { + public final void postponeResponse() { int count = responseWaitCount.incrementAndGet(); assert count > 0 : "response has already been sent"; } @InterfaceStability.Unstable @InterfaceAudience.LimitedPrivate({"HDFS"}) - public void sendResponse() throws IOException { + public final void sendResponse() throws IOException { int count = responseWaitCount.decrementAndGet(); assert count >= 0 : "response has already been sent"; if (count == 0) { - connection.sendResponse(this); + doResponse(null); } } @InterfaceStability.Unstable @InterfaceAudience.LimitedPrivate({"HDFS"}) - public void abortResponse(Throwable t) throws IOException { + public final void abortResponse(Throwable t) throws IOException { // don't send response if the call was already sent or aborted. if (responseWaitCount.getAndSet(-1) > 0) { - connection.abortResponse(this, t); + doResponse(t); } } + void doResponse(Throwable t) throws IOException {} + // For Schedulable @Override public UserGroupInformation getUserGroupInformation() { - return connection.user; + return getRemoteUser(); } @Override @@ -763,6 +766,114 @@ public abstract class Server { } } + /** A RPC extended call queued for handling. */ + private class RpcCall extends Call { + final Connection connection; // connection to client + final Writable rpcRequest; // Serialized Rpc request from client + ByteBuffer rpcResponse; // the response for this call + + RpcCall(RpcCall call) { + super(call); + this.connection = call.connection; + this.rpcRequest = call.rpcRequest; + } + + RpcCall(Connection connection, int id) { + this(connection, id, RpcConstants.INVALID_RETRY_COUNT); + } + + RpcCall(Connection connection, int id, int retryCount) { + this(connection, id, retryCount, null, + RPC.RpcKind.RPC_BUILTIN, RpcConstants.DUMMY_CLIENT_ID, + null, null); + } + + RpcCall(Connection connection, int id, int retryCount, + Writable param, RPC.RpcKind kind, byte[] clientId, + TraceScope traceScope, CallerContext context) { + super(id, retryCount, kind, clientId, traceScope, context); + this.connection = connection; + this.rpcRequest = param; + } + + @Override + public UserGroupInformation getRemoteUser() { + return connection.user; + } + + @Override + public InetAddress getHostInetAddress() { + return connection.getHostInetAddress(); + } + + @Override + public Void run() throws Exception { + if (!connection.channel.isOpen()) { + Server.LOG.info(Thread.currentThread().getName() + ": skipped " + this); + return null; + } + String errorClass = null; + String error = null; + RpcStatusProto returnStatus = RpcStatusProto.SUCCESS; + RpcErrorCodeProto detailedErr = null; + Writable value = null; + + try { + value = call( + rpcKind, connection.protocolName, rpcRequest, timestamp); + } catch (Throwable e) { + if (e instanceof UndeclaredThrowableException) { + e = e.getCause(); + } + logException(Server.LOG, e, this); + if (e instanceof RpcServerException) { + RpcServerException rse = ((RpcServerException)e); + returnStatus = rse.getRpcStatusProto(); + detailedErr = rse.getRpcErrorCodeProto(); + } else { + returnStatus = RpcStatusProto.ERROR; + detailedErr = RpcErrorCodeProto.ERROR_APPLICATION; + } + errorClass = e.getClass().getName(); + error = StringUtils.stringifyException(e); + // Remove redundant error class name from the beginning of the + // stack trace + String exceptionHdr = errorClass + ": "; + if (error.startsWith(exceptionHdr)) { + error = error.substring(exceptionHdr.length()); + } + } + setupResponse(this, returnStatus, detailedErr, + value, errorClass, error); + sendResponse(); + return null; + } + + void setResponse(ByteBuffer response) throws IOException { + this.rpcResponse = response; + } + + @Override + void doResponse(Throwable t) throws IOException { + RpcCall call = this; + if (t != null) { + // clone the call to prevent a race with another thread stomping + // on the response while being sent. the original call is + // effectively discarded since the wait count won't hit zero + call = new RpcCall(this); + setupResponse(call, + RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER, + null, t.getClass().getName(), StringUtils.stringifyException(t)); + } + connection.sendResponse(call); + } + + @Override + public String toString() { + return super.toString() + " " + rpcRequest + " from " + connection; + } + } + /** Listens on the socket. Creates jobs for the handler threads*/ private class Listener extends Thread { @@ -1093,22 +1204,22 @@ public abstract class Server { if(LOG.isDebugEnabled()) { LOG.debug("Checking for old call responses."); } - ArrayList calls; + ArrayList calls; // get the list of channels from list of keys. synchronized (writeSelector.keys()) { - calls = new ArrayList(writeSelector.keys().size()); + calls = new ArrayList(writeSelector.keys().size()); iter = writeSelector.keys().iterator(); while (iter.hasNext()) { SelectionKey key = iter.next(); - Call call = (Call)key.attachment(); + RpcCall call = (RpcCall)key.attachment(); if (call != null && key.channel() == call.connection.channel) { calls.add(call); } } } - - for(Call call : calls) { + + for (RpcCall call : calls) { doPurge(call, now); } } catch (OutOfMemoryError e) { @@ -1126,7 +1237,7 @@ public abstract class Server { } private void doAsyncWrite(SelectionKey key) throws IOException { - Call call = (Call)key.attachment(); + RpcCall call = (RpcCall)key.attachment(); if (call == null) { return; } @@ -1154,10 +1265,10 @@ public abstract class Server { // Remove calls that have been pending in the responseQueue // for a long time. // - private void doPurge(Call call, long now) { - LinkedList responseQueue = call.connection.responseQueue; + private void doPurge(RpcCall call, long now) { + LinkedList responseQueue = call.connection.responseQueue; synchronized (responseQueue) { - Iterator iter = responseQueue.listIterator(0); + Iterator iter = responseQueue.listIterator(0); while (iter.hasNext()) { call = iter.next(); if (now > call.timestamp + PURGE_INTERVAL) { @@ -1171,12 +1282,12 @@ public abstract class Server { // Processes one response. Returns true if there are no more pending // data for this channel. // - private boolean processResponse(LinkedList responseQueue, + private boolean processResponse(LinkedList responseQueue, boolean inHandler) throws IOException { boolean error = true; boolean done = false; // there is more data for this channel. int numElements = 0; - Call call = null; + RpcCall call = null; try { synchronized (responseQueue) { // @@ -1259,7 +1370,7 @@ public abstract class Server { // // Enqueue a response from the application. // - void doRespond(Call call) throws IOException { + void doRespond(RpcCall call) throws IOException { synchronized (call.connection.responseQueue) { // must only wrap before adding to the responseQueue to prevent // postponed responses from being encrypted and sent out of order. @@ -1357,7 +1468,7 @@ public abstract class Server { private SocketChannel channel; private ByteBuffer data; private ByteBuffer dataLengthBuffer; - private LinkedList responseQueue; + private LinkedList responseQueue; // number of outstanding rpcs private AtomicInteger rpcCount = new AtomicInteger(); private long lastContact; @@ -1384,8 +1495,8 @@ public abstract class Server { public UserGroupInformation attemptingUser = null; // user name before auth // Fake 'call' for failed authorization response - private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALL_ID, - RpcConstants.INVALID_RETRY_COUNT, null, this); + private final RpcCall authFailedCall = + new RpcCall(this, AUTHORIZATION_FAILED_CALL_ID); private boolean sentNegotiate = false; private boolean useWrap = false; @@ -1405,7 +1516,7 @@ public abstract class Server { this.hostAddress = addr.getHostAddress(); } this.remotePort = socket.getPort(); - this.responseQueue = new LinkedList(); + this.responseQueue = new LinkedList(); if (socketSendBufferSize != 0) { try { socket.setSendBufferSize(socketSendBufferSize); @@ -1666,8 +1777,7 @@ public abstract class Server { } private void doSaslReply(Message message) throws IOException { - final Call saslCall = new Call(AuthProtocol.SASL.callId, - RpcConstants.INVALID_RETRY_COUNT, null, this); + final RpcCall saslCall = new RpcCall(this, AuthProtocol.SASL.callId); setupResponse(saslCall, RpcStatusProto.SUCCESS, null, RpcWritable.wrap(message), null, null); @@ -1853,23 +1963,20 @@ public abstract class Server { if (clientVersion >= 9) { // Versions >>9 understand the normal response - Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null, - this); + RpcCall fakeCall = new RpcCall(this, -1); setupResponse(fakeCall, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH, null, VersionMismatch.class.getName(), errMsg); fakeCall.sendResponse(); } else if (clientVersion >= 3) { - Call fakeCall = new Call(-1, RpcConstants.INVALID_RETRY_COUNT, null, - this); + RpcCall fakeCall = new RpcCall(this, -1); // Versions 3 to 8 use older response setupResponseOldVersionFatal(buffer, fakeCall, null, VersionMismatch.class.getName(), errMsg); fakeCall.sendResponse(); } else if (clientVersion == 2) { // Hadoop 0.18.3 - Call fakeCall = new Call(0, RpcConstants.INVALID_RETRY_COUNT, null, - this); + RpcCall fakeCall = new RpcCall(this, 0); DataOutputStream out = new DataOutputStream(buffer); out.writeInt(0); // call ID out.writeBoolean(true); // error @@ -1881,7 +1988,7 @@ public abstract class Server { } private void setupHttpRequestOnIpcPortResponse() throws IOException { - Call fakeCall = new Call(0, RpcConstants.INVALID_RETRY_COUNT, null, this); + RpcCall fakeCall = new RpcCall(this, 0); fakeCall.setResponse(ByteBuffer.wrap( RECEIVED_HTTP_REQ_RESPONSE.getBytes(StandardCharsets.UTF_8))); fakeCall.sendResponse(); @@ -2018,7 +2125,7 @@ public abstract class Server { } } catch (WrappedRpcServerException wrse) { // inform client of error Throwable ioe = wrse.getCause(); - final Call call = new Call(callId, retry, null, this); + final RpcCall call = new RpcCall(this, callId, retry); setupResponse(call, RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null, ioe.getClass().getName(), ioe.getMessage()); @@ -2115,8 +2222,9 @@ public abstract class Server { .build(); } - Call call = new Call(header.getCallId(), header.getRetryCount(), - rpcRequest, this, ProtoUtil.convert(header.getRpcKind()), + RpcCall call = new RpcCall(this, header.getCallId(), + header.getRetryCount(), rpcRequest, + ProtoUtil.convert(header.getRpcKind()), header.getClientId().toByteArray(), traceScope, callerContext); // Save the priority level assignment by the scheduler @@ -2239,21 +2347,10 @@ public abstract class Server { } } - private void sendResponse(Call call) throws IOException { + private void sendResponse(RpcCall call) throws IOException { responder.doRespond(call); } - private void abortResponse(Call call, Throwable t) throws IOException { - // clone the call to prevent a race with the other thread stomping - // on the response while being sent. the original call is - // effectively discarded since the wait count won't hit zero - call = new Call(call); - setupResponse(call, - RpcStatusProto.FATAL, RpcErrorCodeProto.ERROR_RPC_SERVER, - null, t.getClass().getName(), StringUtils.stringifyException(t)); - call.sendResponse(); - } - /** * Get service class for connection * @return the serviceClass @@ -2304,16 +2401,6 @@ public abstract class Server { if (LOG.isDebugEnabled()) { LOG.debug(Thread.currentThread().getName() + ": " + call + " for RpcKind " + call.rpcKind); } - if (!call.connection.channel.isOpen()) { - LOG.info(Thread.currentThread().getName() + ": skipped " + call); - continue; - } - String errorClass = null; - String error = null; - RpcStatusProto returnStatus = RpcStatusProto.SUCCESS; - RpcErrorCodeProto detailedErr = null; - Writable value = null; - CurCall.set(call); if (call.traceScope != null) { call.traceScope.reattach(); @@ -2322,53 +2409,11 @@ public abstract class Server { } // always update the current call context CallerContext.setCurrent(call.callerContext); - - try { - // Make the call as the user via Subject.doAs, thus associating - // the call with the Subject - if (call.connection.user == null) { - value = call(call.rpcKind, call.connection.protocolName, call.rpcRequest, - call.timestamp); - } else { - value = - call.connection.user.doAs - (new PrivilegedExceptionAction() { - @Override - public Writable run() throws Exception { - // make the call - return call(call.rpcKind, call.connection.protocolName, - call.rpcRequest, call.timestamp); - - } - } - ); - } - } catch (Throwable e) { - if (e instanceof UndeclaredThrowableException) { - e = e.getCause(); - } - logException(LOG, e, call); - if (e instanceof RpcServerException) { - RpcServerException rse = ((RpcServerException)e); - returnStatus = rse.getRpcStatusProto(); - detailedErr = rse.getRpcErrorCodeProto(); - } else { - returnStatus = RpcStatusProto.ERROR; - detailedErr = RpcErrorCodeProto.ERROR_APPLICATION; - } - errorClass = e.getClass().getName(); - error = StringUtils.stringifyException(e); - // Remove redundant error class name from the beginning of the stack trace - String exceptionHdr = errorClass + ": "; - if (error.startsWith(exceptionHdr)) { - error = error.substring(exceptionHdr.length()); - } - } - CurCall.set(null); - synchronized (call.connection.responseQueue) { - setupResponse(call, returnStatus, detailedErr, - value, errorClass, error); - call.sendResponse(); + UserGroupInformation remoteUser = call.getRemoteUser(); + if (remoteUser != null) { + remoteUser.doAs(call); + } else { + call.run(); } } catch (InterruptedException e) { if (running) { // unexpected -- log it @@ -2385,6 +2430,7 @@ public abstract class Server { StringUtils.stringifyException(e)); } } finally { + CurCall.set(null); IOUtils.cleanup(LOG, traceScope); } } @@ -2586,7 +2632,7 @@ public abstract class Server { * @throws IOException */ private void setupResponse( - Call call, RpcStatusProto status, RpcErrorCodeProto erCode, + RpcCall call, RpcStatusProto status, RpcErrorCodeProto erCode, Writable rv, String errorClass, String error) throws IOException { RpcResponseHeaderProto.Builder headerBuilder = @@ -2620,7 +2666,7 @@ public abstract class Server { } } - private void setupResponse(Call call, + private void setupResponse(RpcCall call, RpcResponseHeaderProto header, Writable rv) throws IOException { ResponseBuffer buf = responseBuffer.get().reset(); try { @@ -2654,7 +2700,7 @@ public abstract class Server { * @throws IOException */ private void setupResponseOldVersionFatal(ByteArrayOutputStream response, - Call call, + RpcCall call, Writable rv, String errorClass, String error) throws IOException { final int OLD_VERSION_FATAL_STATUS = -1; @@ -2667,7 +2713,7 @@ public abstract class Server { call.setResponse(ByteBuffer.wrap(response.toByteArray())); } - private void wrapWithSasl(Call call) throws IOException { + private void wrapWithSasl(RpcCall call) throws IOException { if (call.connection.saslServer != null) { byte[] token = call.rpcResponse.array(); // synchronization may be needed since there can be multiple Handler