From 47bbe431bf85727357feb85650c1b57a3e99f113 Mon Sep 17 00:00:00 2001 From: Kihwal Lee Date: Thu, 9 Feb 2017 11:04:29 -0600 Subject: [PATCH] HADOOP-14034. Allow ipc layer exceptions to selectively close connections. Contributed by Daryn Sharp. (cherry picked from commit d008b5515304b42faeb48e542c8c27586b8564eb) Conflicts: hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java --- .../java/org/apache/hadoop/ipc/Server.java | 213 ++++++++++-------- .../java/org/apache/hadoop/ipc/TestRPC.java | 119 ++++++++++ 2 files changed, 235 insertions(+), 97 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index c99d553e937..ccdd7768a8d 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -1127,20 +1127,16 @@ void doRead(SelectionKey key) throws InterruptedException { LOG.info(Thread.currentThread().getName() + ": readAndProcess caught InterruptedException", ieo); throw ieo; } catch (Exception e) { - // Do not log WrappedRpcServerExceptionSuppressed. - if (!(e instanceof WrappedRpcServerExceptionSuppressed)) { - // A WrappedRpcServerException is an exception that has been sent - // to the client, so the stacktrace is unnecessary; any other - // exceptions are unexpected internal server errors and thus the - // stacktrace should be logged. - LOG.info(Thread.currentThread().getName() + - ": readAndProcess from client " + c.getHostAddress() + - " threw exception [" + e + "]", - (e instanceof WrappedRpcServerException) ? null : e); - } + // Any exceptions that reach here are fatal unexpected internal errors + // that could not be sent to the client. + LOG.info(Thread.currentThread().getName() + + ": readAndProcess from client " + c + + " threw exception [" + e + "]", e); count = -1; //so that the (count < 0) block is executed } - if (count < 0) { + // setupResponse will signal the connection should be closed when a + // fatal response is sent. + if (count < 0 || c.shouldClose()) { closeConnection(c); c = null; } @@ -1468,16 +1464,20 @@ static AuthProtocol valueOf(int callId) { * unnecessary stack trace logging if it's not an internal server error. */ @SuppressWarnings("serial") - private static class WrappedRpcServerException extends RpcServerException { + private static class FatalRpcServerException extends RpcServerException { private final RpcErrorCodeProto errCode; - public WrappedRpcServerException(RpcErrorCodeProto errCode, IOException ioe) { + public FatalRpcServerException(RpcErrorCodeProto errCode, IOException ioe) { super(ioe.toString(), ioe); this.errCode = errCode; } - public WrappedRpcServerException(RpcErrorCodeProto errCode, String message) { + public FatalRpcServerException(RpcErrorCodeProto errCode, String message) { this(errCode, new RpcServerException(message)); } @Override + public RpcStatusProto getRpcStatusProto() { + return RpcStatusProto.FATAL; + } + @Override public RpcErrorCodeProto getRpcErrorCodeProto() { return errCode; } @@ -1487,19 +1487,6 @@ public String toString() { } } - /** - * A WrappedRpcServerException that is suppressed altogether - * for the purposes of logging. - */ - @SuppressWarnings("serial") - private static class WrappedRpcServerExceptionSuppressed - extends WrappedRpcServerException { - public WrappedRpcServerExceptionSuppressed( - RpcErrorCodeProto errCode, IOException ioe) { - super(errCode, ioe); - } - } - /** Reads calls from a connection and queues them for handling. */ public class Connection { private boolean connectionHeaderRead = false; // connection header is read? @@ -1531,7 +1518,8 @@ public class Connection { private ByteBuffer unwrappedData; private ByteBuffer unwrappedDataLengthBuffer; private int serviceClass; - + private boolean shouldClose = false; + UserGroupInformation user = null; public UserGroupInformation attemptingUser = null; // user name before auth @@ -1572,7 +1560,15 @@ public Connection(SocketChannel channel, long lastContact) { public String toString() { return getHostAddress() + ":" + remotePort; } - + + boolean setShouldClose() { + return shouldClose = true; + } + + boolean shouldClose() { + return shouldClose; + } + public String getHostAddress() { return hostAddress; } @@ -1622,13 +1618,13 @@ private UserGroupInformation getAuthorizedUgi(String authorizedId) } private void saslReadAndProcess(RpcWritable.Buffer buffer) throws - WrappedRpcServerException, IOException, InterruptedException { + RpcServerException, IOException, InterruptedException { final RpcSaslProto saslMessage = getMessage(RpcSaslProto.getDefaultInstance(), buffer); switch (saslMessage.getState()) { case WRAP: { if (!saslContextEstablished || !useWrap) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, new SaslException("Server is not wrapping data")); } @@ -1663,11 +1659,11 @@ private Throwable getCauseForInvalidToken(IOException e) { } return e; } - + private void saslProcess(RpcSaslProto saslMessage) - throws WrappedRpcServerException, IOException, InterruptedException { + throws RpcServerException, IOException, InterruptedException { if (saslContextEstablished) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, new SaslException("Negotiation is already complete")); } @@ -1701,10 +1697,10 @@ private void saslProcess(RpcSaslProto saslMessage) AUDITLOG.info(AUTH_SUCCESSFUL_FOR + user); saslContextEstablished = true; } - } catch (WrappedRpcServerException wrse) { // don't re-wrap - throw wrse; + } catch (RpcServerException rse) { // don't re-wrap + throw rse; } catch (IOException ioe) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_UNAUTHORIZED, ioe); } // send back response if any, may throw IOException @@ -1822,14 +1818,14 @@ private void doSaslReply(Message message) throws IOException { setupResponse(saslCall, RpcStatusProto.SUCCESS, null, RpcWritable.wrap(message), null, null); - saslCall.sendResponse(); + sendResponse(saslCall); } private void doSaslReply(Exception ioe) throws IOException { setupResponse(authFailedCall, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_UNAUTHORIZED, null, ioe.getClass().getName(), ioe.getLocalizedMessage()); - authFailedCall.sendResponse(); + sendResponse(authFailedCall); } private void disposeSasl() { @@ -1858,12 +1854,8 @@ private void checkDataLength(int dataLength) throws IOException { } } - public int readAndProcess() - throws WrappedRpcServerException, IOException, InterruptedException { - while (true) { - /* Read at most one RPC. If the header is not read completely yet - * then iterate until we read first RPC or until there is no data left. - */ + public int readAndProcess() throws IOException, InterruptedException { + while (!shouldClose()) { // stop if a fatal response has been sent. int count = -1; if (dataLengthBuffer.remaining() > 0) { count = channelRead(channel, dataLengthBuffer); @@ -1925,15 +1917,17 @@ public int readAndProcess() if (data.remaining() == 0) { dataLengthBuffer.clear(); data.flip(); + ByteBuffer requestData = data; + data = null; // null out in case processOneRpc throws. boolean isHeaderRead = connectionContextRead; - processOneRpc(data); - data = null; + processOneRpc(requestData); if (!isHeaderRead) { continue; } } return count; } + return -1; } private AuthProtocol initializeAuthContext(int authType) @@ -2008,14 +2002,14 @@ private void setupBadVersionResponse(int clientVersion) throws IOException { setupResponse(fakeCall, RpcStatusProto.FATAL, RpcErrorCodeProto.FATAL_VERSION_MISMATCH, null, VersionMismatch.class.getName(), errMsg); - fakeCall.sendResponse(); + sendResponse(fakeCall); } else if (clientVersion >= 3) { RpcCall fakeCall = new RpcCall(this, -1); // Versions 3 to 8 use older response setupResponseOldVersionFatal(buffer, fakeCall, null, VersionMismatch.class.getName(), errMsg); - fakeCall.sendResponse(); + sendResponse(fakeCall); } else if (clientVersion == 2) { // Hadoop 0.18.3 RpcCall fakeCall = new RpcCall(this, 0); DataOutputStream out = new DataOutputStream(buffer); @@ -2024,7 +2018,7 @@ private void setupBadVersionResponse(int clientVersion) throws IOException { WritableUtils.writeString(out, VersionMismatch.class.getName()); WritableUtils.writeString(out, errMsg); fakeCall.setResponse(ByteBuffer.wrap(buffer.toByteArray())); - fakeCall.sendResponse(); + sendResponse(fakeCall); } } @@ -2032,19 +2026,19 @@ private void setupHttpRequestOnIpcPortResponse() throws IOException { RpcCall fakeCall = new RpcCall(this, 0); fakeCall.setResponse(ByteBuffer.wrap( RECEIVED_HTTP_REQ_RESPONSE.getBytes(StandardCharsets.UTF_8))); - fakeCall.sendResponse(); + sendResponse(fakeCall); } /** Reads the connection context following the connection header * @param buffer - DataInputStream from which to read the header - * @throws WrappedRpcServerException - if the header cannot be + * @throws RpcServerException - if the header cannot be * deserialized, or the user is not authorized */ private void processConnectionContext(RpcWritable.Buffer buffer) - throws WrappedRpcServerException { + throws RpcServerException { // allow only one connection context during a session if (connectionContextRead) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "Connection context already processed"); } @@ -2065,7 +2059,7 @@ private void processConnectionContext(RpcWritable.Buffer buffer) && (!protocolUser.getUserName().equals(user.getUserName()))) { if (authMethod == AuthMethod.TOKEN) { // Not allowed to doAs if token authentication is used - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_UNAUTHORIZED, new AccessControlException("Authenticated user (" + user + ") doesn't match what the client claims to be (" @@ -2096,7 +2090,7 @@ private void processConnectionContext(RpcWritable.Buffer buffer) * @throws InterruptedException */ private void unwrapPacketAndProcessRpcs(byte[] inBuf) - throws WrappedRpcServerException, IOException, InterruptedException { + throws IOException, InterruptedException { if (LOG.isDebugEnabled()) { LOG.debug("Have read input token of size " + inBuf.length + " for processing by saslServer.unwrap()"); @@ -2105,7 +2099,7 @@ private void unwrapPacketAndProcessRpcs(byte[] inBuf) ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream( inBuf)); // Read all RPCs contained in the inBuf, even partial ones - while (true) { + while (!shouldClose()) { // stop if a fatal response has been sent. int count = -1; if (unwrappedDataLengthBuffer.remaining() > 0) { count = channelRead(ch, unwrappedDataLengthBuffer); @@ -2126,25 +2120,34 @@ private void unwrapPacketAndProcessRpcs(byte[] inBuf) if (unwrappedData.remaining() == 0) { unwrappedDataLengthBuffer.clear(); unwrappedData.flip(); - processOneRpc(unwrappedData); - unwrappedData = null; + ByteBuffer requestData = unwrappedData; + unwrappedData = null; // null out in case processOneRpc throws. + processOneRpc(requestData); } } } /** - * Process an RPC Request - handle connection setup and decoding of - * request into a Call + * Process one RPC Request from buffer read from socket stream + * - decode rpc in a rpc-Call + * - handle out-of-band RPC requests such as the initial connectionContext + * - A successfully decoded RpcCall will be deposited in RPC-Q and + * its response will be sent later when the request is processed. + * + * Prior to this call the connectionHeader ("hrpc...") has been handled and + * if SASL then SASL has been established and the buf we are passed + * has been unwrapped from SASL. + * * @param bb - contains the RPC request header and the rpc request * @throws IOException - internal error that should not be returned to * client, typically failure to respond to client - * @throws WrappedRpcServerException - an exception to be sent back to - * the client that does not require verbose logging by the - * Listener thread * @throws InterruptedException */ private void processOneRpc(ByteBuffer bb) - throws IOException, WrappedRpcServerException, InterruptedException { + throws IOException, InterruptedException { + // exceptions that escape this method are fatal to the connection. + // setupResponse will use the rpc status to determine if the connection + // should be closed. int callId = -1; int retry = RpcConstants.INVALID_RETRY_COUNT; try { @@ -2161,40 +2164,47 @@ private void processOneRpc(ByteBuffer bb) if (callId < 0) { // callIds typically used during connection setup processRpcOutOfBandRequest(header, buffer); } else if (!connectionContextRead) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "Connection context not established"); } else { processRpcRequest(header, buffer); } - } catch (WrappedRpcServerException wrse) { // inform client of error - Throwable ioe = wrse.getCause(); + } catch (RpcServerException rse) { + // inform client of error, but do not rethrow else non-fatal + // exceptions will close connection! + if (LOG.isDebugEnabled()) { + LOG.debug(Thread.currentThread().getName() + + ": processOneRpc from client " + this + + " threw exception [" + rse + "]"); + } + // use the wrapped exception if there is one. + Throwable t = (rse.getCause() != null) ? rse.getCause() : rse; final RpcCall call = new RpcCall(this, callId, retry); setupResponse(call, - RpcStatusProto.FATAL, wrse.getRpcErrorCodeProto(), null, - ioe.getClass().getName(), ioe.getMessage()); - call.sendResponse(); - throw wrse; + rse.getRpcStatusProto(), rse.getRpcErrorCodeProto(), null, + t.getClass().getName(), t.getMessage()); + sendResponse(call); } } /** * Verify RPC header is valid * @param header - RPC request header - * @throws WrappedRpcServerException - header contains invalid values + * @throws RpcServerException - header contains invalid values */ private void checkRpcHeaders(RpcRequestHeaderProto header) - throws WrappedRpcServerException { + throws RpcServerException { if (!header.hasRpcOp()) { String err = " IPC Server: No rpc op in rpcRequestHeader"; - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } if (header.getRpcOp() != RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET) { String err = "IPC Server does not implement rpc header operation" + header.getRpcOp(); - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } // If we know the rpc kind, get its class so that we can deserialize @@ -2202,7 +2212,7 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) // we continue with this original design. if (!header.hasRpcKind()) { String err = " IPC Server: No rpc kind in rpcRequestHeader"; - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } } @@ -2212,13 +2222,13 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) * have been already read * @param header - RPC request header * @param buffer - stream to request payload - * @throws WrappedRpcServerException - due to fatal rpc layer issues such - * as invalid header or deserialization error. In this case a RPC fatal - * status response will later be sent back to client. + * @throws RpcServerException - generally due to fatal rpc layer issues + * such as invalid header or deserialization error. The call queue + * may also throw a fatal or non-fatal exception on overflow. * @throws InterruptedException */ private void processRpcRequest(RpcRequestHeaderProto header, - RpcWritable.Buffer buffer) throws WrappedRpcServerException, + RpcWritable.Buffer buffer) throws RpcServerException, InterruptedException { Class rpcRequestClass = getRpcRequestWrapper(header.getRpcKind()); @@ -2227,18 +2237,20 @@ private void processRpcRequest(RpcRequestHeaderProto header, " from client " + getHostAddress()); final String err = "Unknown rpc kind in rpc header" + header.getRpcKind(); - throw new WrappedRpcServerException( - RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); + throw new FatalRpcServerException( + RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } Writable rpcRequest; try { //Read the rpc request rpcRequest = buffer.newInstance(rpcRequestClass, conf); + } catch (RpcServerException rse) { // lets tests inject failures. + throw rse; } catch (Throwable t) { // includes runtime exception from newInstance LOG.warn("Unable to read call parameters for client " + getHostAddress() + "on connection protocol " + this.protocolName + " for rpcKind " + header.getRpcKind(), t); String err = "IPC server unable to read call parameters: "+ t.getMessage(); - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST, err); } @@ -2277,7 +2289,7 @@ private void processRpcRequest(RpcRequestHeaderProto header, try { queueCall(call); } catch (IOException ioe) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.ERROR_RPC_SERVER, ioe); } incRpcCount(); // Increment the rpc count @@ -2288,20 +2300,20 @@ private void processRpcRequest(RpcRequestHeaderProto header, * reading and authorizing the connection header * @param header - RPC header * @param buffer - stream to request payload - * @throws WrappedRpcServerException - setup failed due to SASL + * @throws RpcServerException - setup failed due to SASL * negotiation failure, premature or invalid connection context, * or other state errors * @throws IOException - failed to send a response back to the client * @throws InterruptedException */ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, - RpcWritable.Buffer buffer) throws WrappedRpcServerException, + RpcWritable.Buffer buffer) throws RpcServerException, IOException, InterruptedException { final int callId = header.getCallId(); if (callId == CONNECTION_CONTEXT_CALL_ID) { // SASL must be established prior to connection context if (authProtocol == AuthProtocol.SASL && !saslContextEstablished) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "Connection header sent during SASL negotiation"); } @@ -2310,7 +2322,7 @@ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, } else if (callId == AuthProtocol.SASL.callId) { // if client was switched to simple, ignore first SASL message if (authProtocol != AuthProtocol.SASL) { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "SASL protocol not requested by client"); } @@ -2318,7 +2330,7 @@ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, } else if (callId == PING_CALL_ID) { LOG.debug("Received ping message"); } else { - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "Unknown out of band call #" + callId); } @@ -2326,9 +2338,9 @@ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, /** * Authorize proxy users to access this server - * @throws WrappedRpcServerException - user is not allowed to proxy + * @throws RpcServerException - user is not allowed to proxy */ - private void authorizeConnection() throws WrappedRpcServerException { + private void authorizeConnection() throws RpcServerException { try { // If auth method is TOKEN, the token was obtained by the // real user for the effective user, therefore not required to @@ -2348,7 +2360,7 @@ private void authorizeConnection() throws WrappedRpcServerException { + " for protocol " + connectionContext.getProtocol() + " is unauthorized for user " + user); rpcMetrics.incrAuthorizationFailures(); - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_UNAUTHORIZED, ae); } } @@ -2358,21 +2370,24 @@ private void authorizeConnection() throws WrappedRpcServerException { * @param message - Representation of the type of message * @param buffer - a buffer to read the protobuf * @return Message - decoded protobuf - * @throws WrappedRpcServerException - deserialization failed + * @throws RpcServerException - deserialization failed */ @SuppressWarnings("unchecked") T getMessage(Message message, - RpcWritable.Buffer buffer) throws WrappedRpcServerException { + RpcWritable.Buffer buffer) throws RpcServerException { try { return (T)buffer.getValue(message); } catch (Exception ioe) { Class protoClass = message.getClass(); - throw new WrappedRpcServerException( + throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_DESERIALIZING_REQUEST, "Error decoding " + protoClass.getSimpleName() + ": "+ ioe); } } + // ipc reader threads should invoke this directly, whereas handlers + // must invoke call.sendResponse to allow lifecycle management of + // external, postponed, deferred calls, etc. private void sendResponse(RpcCall call) throws IOException { responder.doRespond(call); } @@ -2676,6 +2691,10 @@ private void setupResponse( RpcCall call, RpcStatusProto status, RpcErrorCodeProto erCode, Writable rv, String errorClass, String error) throws IOException { + // fatal responses will cause the reader to close the connection. + if (status == RpcStatusProto.FATAL) { + call.connection.setShouldClose(); + } RpcResponseHeaderProto.Builder headerBuilder = RpcResponseHeaderProto.newBuilder(); headerBuilder.setClientId(ByteString.copyFrom(call.clientId)); diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java index 3b8ba03b09c..ecc71b69ba7 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestRPC.java @@ -31,9 +31,12 @@ import org.apache.hadoop.io.retry.RetryProxy; import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.ipc.Server.Call; +import org.apache.hadoop.ipc.Server.Connection; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto; +import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto; import org.apache.hadoop.ipc.protobuf.TestProtos; import org.apache.hadoop.metrics2.MetricsRecordBuilder; +import org.apache.hadoop.metrics2.lib.MutableCounterLong; import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.AccessControlException; import org.apache.hadoop.security.SecurityUtil; @@ -64,6 +67,7 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; import java.security.PrivilegedAction; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; @@ -77,6 +81,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -85,6 +90,10 @@ import static org.apache.hadoop.test.MetricsAsserts.getLongCounter; import static org.apache.hadoop.test.MetricsAsserts.getMetrics; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.spy; @@ -1365,6 +1374,116 @@ public void testClientRpcTimeout() throws Exception { } } + public static class FakeRequestClass extends RpcWritable { + static volatile IOException exception; + @Override + void writeTo(ResponseBuffer out) throws IOException { + throw new UnsupportedOperationException(); + } + @Override + T readFrom(ByteBuffer bb) throws IOException { + throw exception; + } + } + + @SuppressWarnings("serial") + public static class TestReaderException extends IOException { + public TestReaderException(String msg) { + super(msg); + } + @Override + public boolean equals(Object t) { + return (t.getClass() == TestReaderException.class) && + getMessage().equals(((TestReaderException)t).getMessage()); + } + } + + @Test (timeout=30000) + public void testReaderExceptions() throws Exception { + Server server = null; + TestRpcService proxy = null; + + // will attempt to return this exception from a reader with and w/o + // the connection closing. + IOException expectedIOE = new TestReaderException("testing123"); + + @SuppressWarnings("serial") + IOException rseError = new RpcServerException("keepalive", expectedIOE){ + @Override + public RpcStatusProto getRpcStatusProto() { + return RpcStatusProto.ERROR; + } + }; + @SuppressWarnings("serial") + IOException rseFatal = new RpcServerException("disconnect", expectedIOE) { + @Override + public RpcStatusProto getRpcStatusProto() { + return RpcStatusProto.FATAL; + } + }; + + try { + RPC.Builder builder = newServerBuilder(conf) + .setQueueSizePerHandler(1).setNumHandlers(1).setVerbose(true); + server = setupTestServer(builder); + Whitebox.setInternalState( + server, "rpcRequestClass", FakeRequestClass.class); + MutableCounterLong authMetric = + (MutableCounterLong)Whitebox.getInternalState( + server.getRpcMetrics(), "rpcAuthorizationSuccesses"); + + proxy = getClient(addr, conf); + boolean isDisconnected = true; + Connection lastConn = null; + long expectedAuths = 0; + + // fuzz the client. + for (int i=0; i < 128; i++) { + String reqName = "request[" + i + "]"; + int r = ThreadLocalRandom.current().nextInt(); + final boolean doDisconnect = r % 4 == 0; + LOG.info("TestDisconnect request[" + i + "] " + + " shouldConnect=" + isDisconnected + + " willDisconnect=" + doDisconnect); + if (isDisconnected) { + expectedAuths++; + } + try { + FakeRequestClass.exception = doDisconnect ? rseFatal : rseError; + proxy.ping(null, newEmptyRequest()); + fail(reqName + " didn't fail"); + } catch (ServiceException e) { + RemoteException re = (RemoteException)e.getCause(); + assertEquals(reqName, expectedIOE, re.unwrapRemoteException()); + } + // check authorizations to ensure new connection when expected, + // then conclusively determine if connections are disconnected + // correctly. + assertEquals(reqName, expectedAuths, authMetric.value()); + if (!doDisconnect) { + // if it wasn't fatal, verify there's only one open connection. + Connection[] conns = server.getConnections(); + assertEquals(reqName, 1, conns.length); + // verify whether the connection should have been reused. + if (isDisconnected) { + assertNotSame(reqName, lastConn, conns[0]); + } else { + assertSame(reqName, lastConn, conns[0]); + } + lastConn = conns[0]; + } else if (lastConn != null) { + // avoid race condition in server where connection may not be + // fully removed yet. just make sure it's marked for being closed. + // the open connection checks above ensure correct behavior. + assertTrue(reqName, lastConn.shouldClose()); + } + isDisconnected = doDisconnect; + } + } finally { + stop(server, proxy); + } + } + public static void main(String[] args) throws Exception { new TestRPC().testCallsInternal(conf); }