HADOOP-14034. Allow ipc layer exceptions to selectively close connections. Contributed by Daryn Sharp.

(cherry picked from commit d008b55153)

Conflicts:
	hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java
This commit is contained in:
Kihwal Lee 2017-02-09 11:04:29 -06:00
parent bb98856af9
commit 47bbe431bf
2 changed files with 235 additions and 97 deletions

View File

@ -1127,20 +1127,16 @@ public abstract class Server {
LOG.info(Thread.currentThread().getName() + ": readAndProcess caught InterruptedException", ieo); LOG.info(Thread.currentThread().getName() + ": readAndProcess caught InterruptedException", ieo);
throw ieo; throw ieo;
} catch (Exception e) { } catch (Exception e) {
// Do not log WrappedRpcServerExceptionSuppressed. // Any exceptions that reach here are fatal unexpected internal errors
if (!(e instanceof WrappedRpcServerExceptionSuppressed)) { // that could not be sent to the client.
// A WrappedRpcServerException is an exception that has been sent
// to the client, so the stacktrace is unnecessary; any other
// exceptions are unexpected internal server errors and thus the
// stacktrace should be logged.
LOG.info(Thread.currentThread().getName() + LOG.info(Thread.currentThread().getName() +
": readAndProcess from client " + c.getHostAddress() + ": readAndProcess from client " + c +
" threw exception [" + e + "]", " threw exception [" + e + "]", e);
(e instanceof WrappedRpcServerException) ? null : e);
}
count = -1; //so that the (count < 0) block is executed count = -1; //so that the (count < 0) block is executed
} }
if (count < 0) { // setupResponse will signal the connection should be closed when a
// fatal response is sent.
if (count < 0 || c.shouldClose()) {
closeConnection(c); closeConnection(c);
c = null; c = null;
} }
@ -1468,16 +1464,20 @@ public abstract class Server {
* unnecessary stack trace logging if it's not an internal server error. * unnecessary stack trace logging if it's not an internal server error.
*/ */
@SuppressWarnings("serial") @SuppressWarnings("serial")
private static class WrappedRpcServerException extends RpcServerException { private static class FatalRpcServerException extends RpcServerException {
private final RpcErrorCodeProto errCode; private final RpcErrorCodeProto errCode;
public WrappedRpcServerException(RpcErrorCodeProto errCode, IOException ioe) { public FatalRpcServerException(RpcErrorCodeProto errCode, IOException ioe) {
super(ioe.toString(), ioe); super(ioe.toString(), ioe);
this.errCode = errCode; this.errCode = errCode;
} }
public WrappedRpcServerException(RpcErrorCodeProto errCode, String message) { public FatalRpcServerException(RpcErrorCodeProto errCode, String message) {
this(errCode, new RpcServerException(message)); this(errCode, new RpcServerException(message));
} }
@Override @Override
public RpcStatusProto getRpcStatusProto() {
return RpcStatusProto.FATAL;
}
@Override
public RpcErrorCodeProto getRpcErrorCodeProto() { public RpcErrorCodeProto getRpcErrorCodeProto() {
return errCode; return errCode;
} }
@ -1487,19 +1487,6 @@ public abstract class Server {
} }
} }
/**
* A WrappedRpcServerException that is suppressed altogether
* for the purposes of logging.
*/
@SuppressWarnings("serial")
private static class WrappedRpcServerExceptionSuppressed
extends WrappedRpcServerException {
public WrappedRpcServerExceptionSuppressed(
RpcErrorCodeProto errCode, IOException ioe) {
super(errCode, ioe);
}
}
/** Reads calls from a connection and queues them for handling. */ /** Reads calls from a connection and queues them for handling. */
public class Connection { public class Connection {
private boolean connectionHeaderRead = false; // connection header is read? private boolean connectionHeaderRead = false; // connection header is read?
@ -1531,6 +1518,7 @@ public abstract class Server {
private ByteBuffer unwrappedData; private ByteBuffer unwrappedData;
private ByteBuffer unwrappedDataLengthBuffer; private ByteBuffer unwrappedDataLengthBuffer;
private int serviceClass; private int serviceClass;
private boolean shouldClose = false;
UserGroupInformation user = null; UserGroupInformation user = null;
public UserGroupInformation attemptingUser = null; // user name before auth public UserGroupInformation attemptingUser = null; // user name before auth
@ -1573,6 +1561,14 @@ public abstract class Server {
return getHostAddress() + ":" + remotePort; return getHostAddress() + ":" + remotePort;
} }
boolean setShouldClose() {
return shouldClose = true;
}
boolean shouldClose() {
return shouldClose;
}
public String getHostAddress() { public String getHostAddress() {
return hostAddress; return hostAddress;
} }
@ -1622,13 +1618,13 @@ public abstract class Server {
} }
private void saslReadAndProcess(RpcWritable.Buffer buffer) throws private void saslReadAndProcess(RpcWritable.Buffer buffer) throws
WrappedRpcServerException, IOException, InterruptedException { RpcServerException, IOException, InterruptedException {
final RpcSaslProto saslMessage = final RpcSaslProto saslMessage =
getMessage(RpcSaslProto.getDefaultInstance(), buffer); getMessage(RpcSaslProto.getDefaultInstance(), buffer);
switch (saslMessage.getState()) { switch (saslMessage.getState()) {
case WRAP: { case WRAP: {
if (!saslContextEstablished || !useWrap) { if (!saslContextEstablished || !useWrap) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
new SaslException("Server is not wrapping data")); new SaslException("Server is not wrapping data"));
} }
@ -1665,9 +1661,9 @@ public abstract class Server {
} }
private void saslProcess(RpcSaslProto saslMessage) private void saslProcess(RpcSaslProto saslMessage)
throws WrappedRpcServerException, IOException, InterruptedException { throws RpcServerException, IOException, InterruptedException {
if (saslContextEstablished) { if (saslContextEstablished) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
new SaslException("Negotiation is already complete")); new SaslException("Negotiation is already complete"));
} }
@ -1701,10 +1697,10 @@ public abstract class Server {
AUDITLOG.info(AUTH_SUCCESSFUL_FOR + user); AUDITLOG.info(AUTH_SUCCESSFUL_FOR + user);
saslContextEstablished = true; saslContextEstablished = true;
} }
} catch (WrappedRpcServerException wrse) { // don't re-wrap } catch (RpcServerException rse) { // don't re-wrap
throw wrse; throw rse;
} catch (IOException ioe) { } catch (IOException ioe) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_UNAUTHORIZED, ioe); RpcErrorCodeProto.FATAL_UNAUTHORIZED, ioe);
} }
// send back response if any, may throw IOException // send back response if any, may throw IOException
@ -1822,14 +1818,14 @@ public abstract class Server {
setupResponse(saslCall, setupResponse(saslCall,
RpcStatusProto.SUCCESS, null, RpcStatusProto.SUCCESS, null,
RpcWritable.wrap(message), null, null); RpcWritable.wrap(message), null, null);
saslCall.sendResponse(); sendResponse(saslCall);
} }
private void doSaslReply(Exception ioe) throws IOException { private void doSaslReply(Exception ioe) throws IOException {
setupResponse(authFailedCall, setupResponse(authFailedCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED,
null, ioe.getClass().getName(), ioe.getLocalizedMessage()); null, ioe.getClass().getName(), ioe.getLocalizedMessage());
authFailedCall.sendResponse(); sendResponse(authFailedCall);
} }
private void disposeSasl() { private void disposeSasl() {
@ -1858,12 +1854,8 @@ public abstract class Server {
} }
} }
public int readAndProcess() public int readAndProcess() throws IOException, InterruptedException {
throws WrappedRpcServerException, IOException, InterruptedException { while (!shouldClose()) { // stop if a fatal response has been sent.
while (true) {
/* Read at most one RPC. If the header is not read completely yet
* then iterate until we read first RPC or until there is no data left.
*/
int count = -1; int count = -1;
if (dataLengthBuffer.remaining() > 0) { if (dataLengthBuffer.remaining() > 0) {
count = channelRead(channel, dataLengthBuffer); count = channelRead(channel, dataLengthBuffer);
@ -1925,15 +1917,17 @@ public abstract class Server {
if (data.remaining() == 0) { if (data.remaining() == 0) {
dataLengthBuffer.clear(); dataLengthBuffer.clear();
data.flip(); data.flip();
ByteBuffer requestData = data;
data = null; // null out in case processOneRpc throws.
boolean isHeaderRead = connectionContextRead; boolean isHeaderRead = connectionContextRead;
processOneRpc(data); processOneRpc(requestData);
data = null;
if (!isHeaderRead) { if (!isHeaderRead) {
continue; continue;
} }
} }
return count; return count;
} }
return -1;
} }
private AuthProtocol initializeAuthContext(int authType) private AuthProtocol initializeAuthContext(int authType)
@ -2008,14 +2002,14 @@ public abstract class Server {
setupResponse(fakeCall, setupResponse(fakeCall,
RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH,
null, VersionMismatch.class.getName(), errMsg); null, VersionMismatch.class.getName(), errMsg);
fakeCall.sendResponse(); sendResponse(fakeCall);
} else if (clientVersion >= 3) { } else if (clientVersion >= 3) {
RpcCall fakeCall = new RpcCall(this, -1); RpcCall fakeCall = new RpcCall(this, -1);
// Versions 3 to 8 use older response // Versions 3 to 8 use older response
setupResponseOldVersionFatal(buffer, fakeCall, setupResponseOldVersionFatal(buffer, fakeCall,
null, VersionMismatch.class.getName(), errMsg); null, VersionMismatch.class.getName(), errMsg);
fakeCall.sendResponse(); sendResponse(fakeCall);
} else if (clientVersion == 2) { // Hadoop 0.18.3 } else if (clientVersion == 2) { // Hadoop 0.18.3
RpcCall fakeCall = new RpcCall(this, 0); RpcCall fakeCall = new RpcCall(this, 0);
DataOutputStream out = new DataOutputStream(buffer); DataOutputStream out = new DataOutputStream(buffer);
@ -2024,7 +2018,7 @@ public abstract class Server {
WritableUtils.writeString(out, VersionMismatch.class.getName()); WritableUtils.writeString(out, VersionMismatch.class.getName());
WritableUtils.writeString(out, errMsg); WritableUtils.writeString(out, errMsg);
fakeCall.setResponse(ByteBuffer.wrap(buffer.toByteArray())); fakeCall.setResponse(ByteBuffer.wrap(buffer.toByteArray()));
fakeCall.sendResponse(); sendResponse(fakeCall);
} }
} }
@ -2032,19 +2026,19 @@ public abstract class Server {
RpcCall fakeCall = new RpcCall(this, 0); RpcCall fakeCall = new RpcCall(this, 0);
fakeCall.setResponse(ByteBuffer.wrap( fakeCall.setResponse(ByteBuffer.wrap(
RECEIVED_HTTP_REQ_RESPONSE.getBytes(StandardCharsets.UTF_8))); RECEIVED_HTTP_REQ_RESPONSE.getBytes(StandardCharsets.UTF_8)));
fakeCall.sendResponse(); sendResponse(fakeCall);
} }
/** Reads the connection context following the connection header /** Reads the connection context following the connection header
* @param buffer - DataInputStream from which to read the header * @param buffer - DataInputStream from which to read the header
* @throws WrappedRpcServerException - if the header cannot be * @throws RpcServerException - if the header cannot be
* deserialized, or the user is not authorized * deserialized, or the user is not authorized
*/ */
private void processConnectionContext(RpcWritable.Buffer buffer) private void processConnectionContext(RpcWritable.Buffer buffer)
throws WrappedRpcServerException { throws RpcServerException {
// allow only one connection context during a session // allow only one connection context during a session
if (connectionContextRead) { if (connectionContextRead) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"Connection context already processed"); "Connection context already processed");
} }
@ -2065,7 +2059,7 @@ public abstract class Server {
&& (!protocolUser.getUserName().equals(user.getUserName()))) { && (!protocolUser.getUserName().equals(user.getUserName()))) {
if (authMethod == AuthMethod.TOKEN) { if (authMethod == AuthMethod.TOKEN) {
// Not allowed to doAs if token authentication is used // Not allowed to doAs if token authentication is used
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_UNAUTHORIZED, RpcErrorCodeProto.FATAL_UNAUTHORIZED,
new AccessControlException("Authenticated user (" + user new AccessControlException("Authenticated user (" + user
+ ") doesn't match what the client claims to be (" + ") doesn't match what the client claims to be ("
@ -2096,7 +2090,7 @@ public abstract class Server {
* @throws InterruptedException * @throws InterruptedException
*/ */
private void unwrapPacketAndProcessRpcs(byte[] inBuf) private void unwrapPacketAndProcessRpcs(byte[] inBuf)
throws WrappedRpcServerException, IOException, InterruptedException { throws IOException, InterruptedException {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("Have read input token of size " + inBuf.length LOG.debug("Have read input token of size " + inBuf.length
+ " for processing by saslServer.unwrap()"); + " for processing by saslServer.unwrap()");
@ -2105,7 +2099,7 @@ public abstract class Server {
ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream( ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
inBuf)); inBuf));
// Read all RPCs contained in the inBuf, even partial ones // Read all RPCs contained in the inBuf, even partial ones
while (true) { while (!shouldClose()) { // stop if a fatal response has been sent.
int count = -1; int count = -1;
if (unwrappedDataLengthBuffer.remaining() > 0) { if (unwrappedDataLengthBuffer.remaining() > 0) {
count = channelRead(ch, unwrappedDataLengthBuffer); count = channelRead(ch, unwrappedDataLengthBuffer);
@ -2126,25 +2120,34 @@ public abstract class Server {
if (unwrappedData.remaining() == 0) { if (unwrappedData.remaining() == 0) {
unwrappedDataLengthBuffer.clear(); unwrappedDataLengthBuffer.clear();
unwrappedData.flip(); unwrappedData.flip();
processOneRpc(unwrappedData); ByteBuffer requestData = unwrappedData;
unwrappedData = null; unwrappedData = null; // null out in case processOneRpc throws.
processOneRpc(requestData);
} }
} }
} }
/** /**
* Process an RPC Request - handle connection setup and decoding of * Process one RPC Request from buffer read from socket stream
* request into a Call * - decode rpc in a rpc-Call
* - handle out-of-band RPC requests such as the initial connectionContext
* - A successfully decoded RpcCall will be deposited in RPC-Q and
* its response will be sent later when the request is processed.
*
* Prior to this call the connectionHeader ("hrpc...") has been handled and
* if SASL then SASL has been established and the buf we are passed
* has been unwrapped from SASL.
*
* @param bb - contains the RPC request header and the rpc request * @param bb - contains the RPC request header and the rpc request
* @throws IOException - internal error that should not be returned to * @throws IOException - internal error that should not be returned to
* client, typically failure to respond to client * client, typically failure to respond to client
* @throws WrappedRpcServerException - an exception to be sent back to
* the client that does not require verbose logging by the
* Listener thread
* @throws InterruptedException * @throws InterruptedException
*/ */
private void processOneRpc(ByteBuffer bb) private void processOneRpc(ByteBuffer bb)
throws IOException, WrappedRpcServerException, InterruptedException { throws IOException, InterruptedException {
// exceptions that escape this method are fatal to the connection.
// setupResponse will use the rpc status to determine if the connection
// should be closed.
int callId = -1; int callId = -1;
int retry = RpcConstants.INVALID_RETRY_COUNT; int retry = RpcConstants.INVALID_RETRY_COUNT;
try { try {
@ -2161,40 +2164,47 @@ public abstract class Server {
if (callId < 0) { // callIds typically used during connection setup if (callId < 0) { // callIds typically used during connection setup
processRpcOutOfBandRequest(header, buffer); processRpcOutOfBandRequest(header, buffer);
} else if (!connectionContextRead) { } else if (!connectionContextRead) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"Connection context not established"); "Connection context not established");
} else { } else {
processRpcRequest(header, buffer); processRpcRequest(header, buffer);
} }
} catch (WrappedRpcServerException wrse) { // inform client of error } catch (RpcServerException rse) {
Throwable ioe = wrse.getCause(); // inform client of error, but do not rethrow else non-fatal
// exceptions will close connection!
if (LOG.isDebugEnabled()) {
LOG.debug(Thread.currentThread().getName() +
": processOneRpc from client " + this +
" threw exception [" + rse + "]");
}
// use the wrapped exception if there is one.
Throwable t = (rse.getCause() != null) ? rse.getCause() : rse;
final RpcCall call = new RpcCall(this, callId, retry); final RpcCall call = new RpcCall(this, callId, retry);
setupResponse(call, setupResponse(call,
RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null, rse.getRpcStatusProto(), rse.getRpcErrorCodeProto(), null,
ioe.getClass().getName(), ioe.getMessage()); t.getClass().getName(), t.getMessage());
call.sendResponse(); sendResponse(call);
throw wrse;
} }
} }
/** /**
* Verify RPC header is valid * Verify RPC header is valid
* @param header - RPC request header * @param header - RPC request header
* @throws WrappedRpcServerException - header contains invalid values * @throws RpcServerException - header contains invalid values
*/ */
private void checkRpcHeaders(RpcRequestHeaderProto header) private void checkRpcHeaders(RpcRequestHeaderProto header)
throws WrappedRpcServerException { throws RpcServerException {
if (!header.hasRpcOp()) { if (!header.hasRpcOp()) {
String err = " IPC Server: No rpc op in rpcRequestHeader"; String err = " IPC Server: No rpc op in rpcRequestHeader";
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
} }
if (header.getRpcOp() != if (header.getRpcOp() !=
RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET) { RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET) {
String err = "IPC Server does not implement rpc header operation" + String err = "IPC Server does not implement rpc header operation" +
header.getRpcOp(); header.getRpcOp();
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
} }
// If we know the rpc kind, get its class so that we can deserialize // If we know the rpc kind, get its class so that we can deserialize
@ -2202,7 +2212,7 @@ public abstract class Server {
// we continue with this original design. // we continue with this original design.
if (!header.hasRpcKind()) { if (!header.hasRpcKind()) {
String err = " IPC Server: No rpc kind in rpcRequestHeader"; String err = " IPC Server: No rpc kind in rpcRequestHeader";
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
} }
} }
@ -2212,13 +2222,13 @@ public abstract class Server {
* have been already read * have been already read
* @param header - RPC request header * @param header - RPC request header
* @param buffer - stream to request payload * @param buffer - stream to request payload
* @throws WrappedRpcServerException - due to fatal rpc layer issues such * @throws RpcServerException - generally due to fatal rpc layer issues
* as invalid header or deserialization error. In this case a RPC fatal * such as invalid header or deserialization error. The call queue
* status response will later be sent back to client. * may also throw a fatal or non-fatal exception on overflow.
* @throws InterruptedException * @throws InterruptedException
*/ */
private void processRpcRequest(RpcRequestHeaderProto header, private void processRpcRequest(RpcRequestHeaderProto header,
RpcWritable.Buffer buffer) throws WrappedRpcServerException, RpcWritable.Buffer buffer) throws RpcServerException,
InterruptedException { InterruptedException {
Class<? extends Writable> rpcRequestClass = Class<? extends Writable> rpcRequestClass =
getRpcRequestWrapper(header.getRpcKind()); getRpcRequestWrapper(header.getRpcKind());
@ -2227,18 +2237,20 @@ public abstract class Server {
" from client " + getHostAddress()); " from client " + getHostAddress());
final String err = "Unknown rpc kind in rpc header" + final String err = "Unknown rpc kind in rpc header" +
header.getRpcKind(); header.getRpcKind();
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err);
} }
Writable rpcRequest; Writable rpcRequest;
try { //Read the rpc request try { //Read the rpc request
rpcRequest = buffer.newInstance(rpcRequestClass, conf); rpcRequest = buffer.newInstance(rpcRequestClass, conf);
} catch (RpcServerException rse) { // lets tests inject failures.
throw rse;
} catch (Throwable t) { // includes runtime exception from newInstance } catch (Throwable t) { // includes runtime exception from newInstance
LOG.warn("Unable to read call parameters for client " + LOG.warn("Unable to read call parameters for client " +
getHostAddress() + "on connection protocol " + getHostAddress() + "on connection protocol " +
this.protocolName + " for rpcKind " + header.getRpcKind(), t); this.protocolName + " for rpcKind " + header.getRpcKind(), t);
String err = "IPC server unable to read call parameters: "+ t.getMessage(); String err = "IPC server unable to read call parameters: "+ t.getMessage();
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST, err); RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST, err);
} }
@ -2277,7 +2289,7 @@ public abstract class Server {
try { try {
queueCall(call); queueCall(call);
} catch (IOException ioe) { } catch (IOException ioe) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.ERROR_RPC_SERVER, ioe); RpcErrorCodeProto.ERROR_RPC_SERVER, ioe);
} }
incRpcCount(); // Increment the rpc count incRpcCount(); // Increment the rpc count
@ -2288,20 +2300,20 @@ public abstract class Server {
* reading and authorizing the connection header * reading and authorizing the connection header
* @param header - RPC header * @param header - RPC header
* @param buffer - stream to request payload * @param buffer - stream to request payload
* @throws WrappedRpcServerException - setup failed due to SASL * @throws RpcServerException - setup failed due to SASL
* negotiation failure, premature or invalid connection context, * negotiation failure, premature or invalid connection context,
* or other state errors * or other state errors
* @throws IOException - failed to send a response back to the client * @throws IOException - failed to send a response back to the client
* @throws InterruptedException * @throws InterruptedException
*/ */
private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, private void processRpcOutOfBandRequest(RpcRequestHeaderProto header,
RpcWritable.Buffer buffer) throws WrappedRpcServerException, RpcWritable.Buffer buffer) throws RpcServerException,
IOException, InterruptedException { IOException, InterruptedException {
final int callId = header.getCallId(); final int callId = header.getCallId();
if (callId == CONNECTION_CONTEXT_CALL_ID) { if (callId == CONNECTION_CONTEXT_CALL_ID) {
// SASL must be established prior to connection context // SASL must be established prior to connection context
if (authProtocol == AuthProtocol.SASL && !saslContextEstablished) { if (authProtocol == AuthProtocol.SASL && !saslContextEstablished) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"Connection header sent during SASL negotiation"); "Connection header sent during SASL negotiation");
} }
@ -2310,7 +2322,7 @@ public abstract class Server {
} else if (callId == AuthProtocol.SASL.callId) { } else if (callId == AuthProtocol.SASL.callId) {
// if client was switched to simple, ignore first SASL message // if client was switched to simple, ignore first SASL message
if (authProtocol != AuthProtocol.SASL) { if (authProtocol != AuthProtocol.SASL) {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"SASL protocol not requested by client"); "SASL protocol not requested by client");
} }
@ -2318,7 +2330,7 @@ public abstract class Server {
} else if (callId == PING_CALL_ID) { } else if (callId == PING_CALL_ID) {
LOG.debug("Received ping message"); LOG.debug("Received ping message");
} else { } else {
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
"Unknown out of band call #" + callId); "Unknown out of band call #" + callId);
} }
@ -2326,9 +2338,9 @@ public abstract class Server {
/** /**
* Authorize proxy users to access this server * Authorize proxy users to access this server
* @throws WrappedRpcServerException - user is not allowed to proxy * @throws RpcServerException - user is not allowed to proxy
*/ */
private void authorizeConnection() throws WrappedRpcServerException { private void authorizeConnection() throws RpcServerException {
try { try {
// If auth method is TOKEN, the token was obtained by the // If auth method is TOKEN, the token was obtained by the
// real user for the effective user, therefore not required to // real user for the effective user, therefore not required to
@ -2348,7 +2360,7 @@ public abstract class Server {
+ " for protocol " + connectionContext.getProtocol() + " for protocol " + connectionContext.getProtocol()
+ " is unauthorized for user " + user); + " is unauthorized for user " + user);
rpcMetrics.incrAuthorizationFailures(); rpcMetrics.incrAuthorizationFailures();
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_UNAUTHORIZED, ae); RpcErrorCodeProto.FATAL_UNAUTHORIZED, ae);
} }
} }
@ -2358,21 +2370,24 @@ public abstract class Server {
* @param message - Representation of the type of message * @param message - Representation of the type of message
* @param buffer - a buffer to read the protobuf * @param buffer - a buffer to read the protobuf
* @return Message - decoded protobuf * @return Message - decoded protobuf
* @throws WrappedRpcServerException - deserialization failed * @throws RpcServerException - deserialization failed
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
<T extends Message> T getMessage(Message message, <T extends Message> T getMessage(Message message,
RpcWritable.Buffer buffer) throws WrappedRpcServerException { RpcWritable.Buffer buffer) throws RpcServerException {
try { try {
return (T)buffer.getValue(message); return (T)buffer.getValue(message);
} catch (Exception ioe) { } catch (Exception ioe) {
Class<?> protoClass = message.getClass(); Class<?> protoClass = message.getClass();
throw new WrappedRpcServerException( throw new FatalRpcServerException(
RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST, RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST,
"Error decoding " + protoClass.getSimpleName() + ": "+ ioe); "Error decoding " + protoClass.getSimpleName() + ": "+ ioe);
} }
} }
// ipc reader threads should invoke this directly, whereas handlers
// must invoke call.sendResponse to allow lifecycle management of
// external, postponed, deferred calls, etc.
private void sendResponse(RpcCall call) throws IOException { private void sendResponse(RpcCall call) throws IOException {
responder.doRespond(call); responder.doRespond(call);
} }
@ -2676,6 +2691,10 @@ public abstract class Server {
RpcCall call, RpcStatusProto status, RpcErrorCodeProto erCode, RpcCall call, RpcStatusProto status, RpcErrorCodeProto erCode,
Writable rv, String errorClass, String error) Writable rv, String errorClass, String error)
throws IOException { throws IOException {
// fatal responses will cause the reader to close the connection.
if (status == RpcStatusProto.FATAL) {
call.connection.setShouldClose();
}
RpcResponseHeaderProto.Builder headerBuilder = RpcResponseHeaderProto.Builder headerBuilder =
RpcResponseHeaderProto.newBuilder(); RpcResponseHeaderProto.newBuilder();
headerBuilder.setClientId(ByteString.copyFrom(call.clientId)); headerBuilder.setClientId(ByteString.copyFrom(call.clientId));

View File

@ -31,9 +31,12 @@ import org.apache.hadoop.io.retry.RetryPolicy;
import org.apache.hadoop.io.retry.RetryProxy; import org.apache.hadoop.io.retry.RetryProxy;
import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.Server.Call; import org.apache.hadoop.ipc.Server.Call;
import org.apache.hadoop.ipc.Server.Connection;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto;
import org.apache.hadoop.ipc.protobuf.TestProtos; import org.apache.hadoop.ipc.protobuf.TestProtos;
import org.apache.hadoop.metrics2.MetricsRecordBuilder; import org.apache.hadoop.metrics2.MetricsRecordBuilder;
import org.apache.hadoop.metrics2.lib.MutableCounterLong;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.SecurityUtil;
@ -64,6 +67,7 @@ import java.net.ConnectException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketTimeoutException; import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.security.PrivilegedAction; import java.security.PrivilegedAction;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.ArrayList; import java.util.ArrayList;
@ -77,6 +81,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -85,6 +90,10 @@ import static org.apache.hadoop.test.MetricsAsserts.assertCounterGt;
import static org.apache.hadoop.test.MetricsAsserts.getLongCounter; import static org.apache.hadoop.test.MetricsAsserts.getLongCounter;
import static org.apache.hadoop.test.MetricsAsserts.getMetrics; import static org.apache.hadoop.test.MetricsAsserts.getMetrics;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
@ -1365,6 +1374,116 @@ public class TestRPC extends TestRpcBase {
} }
} }
public static class FakeRequestClass extends RpcWritable {
static volatile IOException exception;
@Override
void writeTo(ResponseBuffer out) throws IOException {
throw new UnsupportedOperationException();
}
@Override
<T> T readFrom(ByteBuffer bb) throws IOException {
throw exception;
}
}
@SuppressWarnings("serial")
public static class TestReaderException extends IOException {
public TestReaderException(String msg) {
super(msg);
}
@Override
public boolean equals(Object t) {
return (t.getClass() == TestReaderException.class) &&
getMessage().equals(((TestReaderException)t).getMessage());
}
}
@Test (timeout=30000)
public void testReaderExceptions() throws Exception {
Server server = null;
TestRpcService proxy = null;
// will attempt to return this exception from a reader with and w/o
// the connection closing.
IOException expectedIOE = new TestReaderException("testing123");
@SuppressWarnings("serial")
IOException rseError = new RpcServerException("keepalive", expectedIOE){
@Override
public RpcStatusProto getRpcStatusProto() {
return RpcStatusProto.ERROR;
}
};
@SuppressWarnings("serial")
IOException rseFatal = new RpcServerException("disconnect", expectedIOE) {
@Override
public RpcStatusProto getRpcStatusProto() {
return RpcStatusProto.FATAL;
}
};
try {
RPC.Builder builder = newServerBuilder(conf)
.setQueueSizePerHandler(1).setNumHandlers(1).setVerbose(true);
server = setupTestServer(builder);
Whitebox.setInternalState(
server, "rpcRequestClass", FakeRequestClass.class);
MutableCounterLong authMetric =
(MutableCounterLong)Whitebox.getInternalState(
server.getRpcMetrics(), "rpcAuthorizationSuccesses");
proxy = getClient(addr, conf);
boolean isDisconnected = true;
Connection lastConn = null;
long expectedAuths = 0;
// fuzz the client.
for (int i=0; i < 128; i++) {
String reqName = "request[" + i + "]";
int r = ThreadLocalRandom.current().nextInt();
final boolean doDisconnect = r % 4 == 0;
LOG.info("TestDisconnect request[" + i + "] " +
" shouldConnect=" + isDisconnected +
" willDisconnect=" + doDisconnect);
if (isDisconnected) {
expectedAuths++;
}
try {
FakeRequestClass.exception = doDisconnect ? rseFatal : rseError;
proxy.ping(null, newEmptyRequest());
fail(reqName + " didn't fail");
} catch (ServiceException e) {
RemoteException re = (RemoteException)e.getCause();
assertEquals(reqName, expectedIOE, re.unwrapRemoteException());
}
// check authorizations to ensure new connection when expected,
// then conclusively determine if connections are disconnected
// correctly.
assertEquals(reqName, expectedAuths, authMetric.value());
if (!doDisconnect) {
// if it wasn't fatal, verify there's only one open connection.
Connection[] conns = server.getConnections();
assertEquals(reqName, 1, conns.length);
// verify whether the connection should have been reused.
if (isDisconnected) {
assertNotSame(reqName, lastConn, conns[0]);
} else {
assertSame(reqName, lastConn, conns[0]);
}
lastConn = conns[0];
} else if (lastConn != null) {
// avoid race condition in server where connection may not be
// fully removed yet. just make sure it's marked for being closed.
// the open connection checks above ensure correct behavior.
assertTrue(reqName, lastConn.shouldClose());
}
isDisconnected = doDisconnect;
}
} finally {
stop(server, proxy);
}
}
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
new TestRPC().testCallsInternal(conf); new TestRPC().testCallsInternal(conf);
} }