Issue #3809 - ensure abnormal close frame will hard close ws connection (#3819)

* Issue #3809 - ensure abnormal close frame will hard close ws connection

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>

* Issue #3159 - signal onError on abnormal status code close

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan 2019-07-01 16:50:56 +10:00 committed by GitHub
parent fa4abfa6bb
commit ac8910e044
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 89 deletions

View File

@ -52,13 +52,14 @@ public class CloseStatus
private final int code;
private final String reason;
private final Throwable cause;
/**
* Creates a reason for closing a web socket connection with the no given status code.
*/
public CloseStatus()
{
this(NO_CODE);
this(NO_CODE, null, null);
}
/**
@ -68,7 +69,7 @@ public class CloseStatus
*/
public CloseStatus(int statusCode)
{
this(statusCode, null);
this(statusCode, null, null);
}
/**
@ -78,8 +79,32 @@ public class CloseStatus
* @param reasonPhrase the reason phrase
*/
public CloseStatus(int statusCode, String reasonPhrase)
{
this(statusCode, reasonPhrase, null);
}
/**
* Creates a reason for closing a web socket connection with the given status code and reason phrase.
*
* @param statusCode the close code
* @param cause the error which caused the close
*/
public CloseStatus(int statusCode, Throwable cause)
{
this(statusCode, cause.getMessage(), cause);
}
/**
* Creates a reason for closing a web socket connection with the given status code and reason phrase.
*
* @param statusCode the close code
* @param reasonPhrase the reason phrase
* @param cause the error which caused the close
*/
public CloseStatus(int statusCode, String reasonPhrase, Throwable cause)
{
this.code = statusCode;
this.cause = cause;
if (reasonPhrase != null)
{
@ -100,6 +125,7 @@ public class CloseStatus
public CloseStatus(ByteBuffer payload)
{
// RFC-6455 Spec Required Close Frame validation.
this.cause = null;
int statusCode = NO_CODE;
if ((payload == null) || (payload.remaining() == 0))
@ -169,14 +195,22 @@ public class CloseStatus
return ((CloseStatus.Supplier)frame).getCloseStatus();
if (frame.getOpCode() == OpCode.CLOSE)
return new CloseStatus(frame);
return null;
throw new IllegalArgumentException("not a close frame");
}
// TODO consider defining a precedence for every CloseStatus, and change SessionState only if higher precedence
public static boolean isOrdinary(CloseStatus closeStatus)
public static boolean isOrdinary(int closeCode)
{
int code = closeStatus.getCode();
return (code == NORMAL || code == NO_CODE || code >= 3000);
return (closeCode == NORMAL || closeCode == NO_CODE || closeCode >= 3000);
}
public boolean isAbnormal()
{
return !isOrdinary(code);
}
public Throwable getCause()
{
return cause;
}
public int getCode()

View File

@ -44,8 +44,6 @@ import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.WebSocketException;
import org.eclipse.jetty.websocket.core.WebSocketWriteTimeoutException;
import static org.eclipse.jetty.websocket.core.internal.WebSocketCoreSession.AbnormalCloseStatus;
public class FrameFlusher extends IteratingCallback
{
public static final Frame FLUSH_FRAME = new Frame(OpCode.BINARY);
@ -115,7 +113,7 @@ public class FrameFlusher extends IteratingCallback
{
case OpCode.CLOSE:
closeStatus = CloseStatus.getCloseStatus(frame);
if (!CloseStatus.isOrdinary(closeStatus))
if (closeStatus.isAbnormal())
{
//fail all existing entries in the queue, and enqueue the error close
failedEntries = new ArrayList<>(queue);
@ -151,12 +149,8 @@ public class FrameFlusher extends IteratingCallback
if (failedEntries != null)
{
WebSocketException failure = new WebSocketException("Flusher received abnormal CloseFrame: " + CloseStatus.codeString(closeStatus.getCode()));
if (closeStatus instanceof AbnormalCloseStatus)
{
Throwable cause = ((AbnormalCloseStatus)closeStatus).getCause();
failure.initCause(cause);
}
WebSocketException failure = new WebSocketException("Flusher received abnormal CloseFrame: "
+ CloseStatus.codeString(closeStatus.getCode()), closeStatus.getCause());
for (Entry e : failedEntries)
{

View File

@ -114,11 +114,38 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
throw new ProtocolException("Server MUST NOT mask any frames (RFC-6455: Section 5.1)");
break;
}
/*
* RFC 6455 Section 5.5.1
* close frame payload is specially formatted which is checked in CloseStatus
*/
if (frame.getOpCode() == OpCode.CLOSE)
{
if (!(frame instanceof ParsedFrame)) // already check in parser
CloseStatus.getCloseStatus(frame); // return ignored as get used to validate there is a closeStatus
}
}
public void assertValidOutgoing(Frame frame) throws CloseException
{
assertValidFrame(frame);
/*
* RFC 6455 Section 5.5.1
* close frame payload is specially formatted which is checked in CloseStatus
*/
if (frame.getOpCode() == OpCode.CLOSE)
{
if (!(frame instanceof ParsedFrame)) // already check in parser
{
CloseStatus closeStatus = CloseStatus.getCloseStatus(frame);
if (!CloseStatus.isTransmittableStatusCode(closeStatus.getCode()) && (closeStatus.getCode()!=CloseStatus.NO_CODE))
{
throw new ProtocolException("Frame has non-transmittable status code");
}
}
}
}
public void assertValidFrame(Frame frame)
@ -141,16 +168,6 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
throw new ProtocolException("Cannot have RSV2==true on Control frames");
if (frame.isRsv3())
throw new ProtocolException("Cannot have RSV3==true on Control frames");
/*
* RFC 6455 Section 5.5.1
* close frame payload is specially formatted which is checked in CloseStatus
*/
if (frame.getOpCode() == OpCode.CLOSE)
{
if (!(frame instanceof ParsedFrame)) // already check in parser
CloseStatus.getCloseStatus(frame); // return ignored as get used to validate there is a closeStatus
}
}
else
{
@ -283,20 +300,20 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
LOG.debug("onEof() {}", this);
if (sessionState.onEof())
closeConnection(new ClosedChannelException(), sessionState.getCloseStatus(), Callback.NOOP);
closeConnection(sessionState.getCloseStatus(), Callback.NOOP);
}
public void closeConnection(Throwable cause, CloseStatus closeStatus, Callback callback)
public void closeConnection(CloseStatus closeStatus, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("closeConnection() {} {} {}", closeStatus, this, cause);
LOG.debug("closeConnection() {} {} {}", closeStatus, this);
connection.cancelDemand();
if (connection.getEndPoint().isOpen())
connection.close();
// Forward Errors to Local WebSocket EndPoint
if (cause != null)
if (closeStatus.isAbnormal())
{
Callback errorCallback = Callback.from(() ->
{
@ -311,6 +328,7 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
}
});
Throwable cause = closeStatus.getCause() != null ? closeStatus.getCause() : new ClosedChannelException();
try
{
handler.onError(cause, errorCallback);
@ -362,13 +380,13 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
else
code = CloseStatus.NO_CLOSE;
AbnormalCloseStatus closeStatus = new AbnormalCloseStatus(code, cause);
CloseStatus closeStatus = new CloseStatus(code, cause);
if (CloseStatus.isTransmittableStatusCode(code))
close(closeStatus, callback);
else
{
if (sessionState.onClosed(closeStatus))
closeConnection(cause, closeStatus, callback);
closeConnection(closeStatus, callback);
}
}
@ -396,7 +414,7 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
else
code = CloseStatus.SERVER_ERROR;
close(new AbnormalCloseStatus(code, cause), callback);
close(new CloseStatus(code, cause), callback);
}
/**
@ -492,7 +510,7 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
catch (Throwable t)
{
if (LOG.isDebugEnabled())
LOG.warn("Invalid outgoing frame: {}", frame);
LOG.warn("Invalid outgoing frame: " + frame, t);
callback.failed(t);
return;
@ -508,11 +526,9 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
boolean closeConnection = sessionState.onOutgoingFrame(frame);
if (closeConnection)
{
Throwable cause = AbnormalCloseStatus.getCause(CloseStatus.getCloseStatus(frame));
Callback closeConnectionCallback = Callback.from(
() -> closeConnection(cause, sessionState.getCloseStatus(), callback),
t -> closeConnection(cause, sessionState.getCloseStatus(), Callback.from(callback, t)));
() -> closeConnection(sessionState.getCloseStatus(), callback),
t -> closeConnection(sessionState.getCloseStatus(), Callback.from(callback, t)));
flusher.queue.offer(new FrameEntry(frame, closeConnectionCallback, false));
}
@ -531,8 +547,8 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
if (frame.getOpCode() == OpCode.CLOSE)
{
CloseStatus closeStatus = CloseStatus.getCloseStatus(frame);
if (closeStatus instanceof AbnormalCloseStatus && sessionState.onClosed(closeStatus))
closeConnection(AbnormalCloseStatus.getCause(closeStatus), closeStatus, Callback.from(callback, t));
if (closeStatus.isAbnormal() && sessionState.onClosed(closeStatus))
closeConnection(closeStatus, Callback.from(callback, t));
else
callback.failed(t);
}
@ -660,7 +676,7 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
if (closeConnection)
{
closeCallback = Callback.from(() -> closeConnection(null, sessionState.getCloseStatus(), callback));
closeCallback = Callback.from(() -> closeConnection(sessionState.getCloseStatus(), callback));
}
else
{
@ -773,35 +789,6 @@ public class WebSocketCoreSession implements IncomingFrames, FrameHandler.CoreSe
handler);
}
static class AbnormalCloseStatus extends CloseStatus
{
final Throwable cause;
public AbnormalCloseStatus(int statusCode, Throwable cause)
{
super(statusCode, cause.getMessage());
this.cause = cause;
}
public Throwable getCause()
{
return cause;
}
public static Throwable getCause(CloseStatus status)
{
if (status instanceof AbnormalCloseStatus)
return ((AbnormalCloseStatus)status).getCause();
return null;
}
@Override
public String toString()
{
return "Abnormal" + super.toString() + ":" + cause;
}
}
private class Flusher extends IteratingCallback
{
private final Queue<FrameEntry> queue = new ArrayDeque<>();

View File

@ -132,7 +132,7 @@ public class WebSocketSessionState
return false;
default:
if (_closeStatus == null || CloseStatus.isOrdinary(_closeStatus))
if (_closeStatus == null || CloseStatus.isOrdinary(_closeStatus.getCode()))
_closeStatus = new CloseStatus(CloseStatus.NO_CLOSE, "Session Closed");
_sessionState = State.CLOSED;
return true;
@ -153,7 +153,7 @@ public class WebSocketSessionState
if (opcode == OpCode.CLOSE)
{
_closeStatus = CloseStatus.getCloseStatus(frame);
if (_closeStatus instanceof WebSocketCoreSession.AbnormalCloseStatus)
if (_closeStatus.isAbnormal())
{
_sessionState = State.CLOSED;
return true;

View File

@ -411,6 +411,21 @@ public class WebSocketCloseTest extends WebSocketTester
assertThat(server.handler.closeStatus.getReason(), containsString("onReceiveFrame throws for binary frames"));
}
@ParameterizedTest
@ValueSource(strings = {WS_SCHEME, WSS_SCHEME})
public void abnormalCloseStatusIsHardClose(String scheme) throws Exception
{
setup(State.OPEN, scheme);
server.handler.getCoreSession().close(CloseStatus.SERVER_ERROR, "manually sent server error", Callback.NOOP);
assertTrue(server.handler.closed.await(5, TimeUnit.SECONDS));
assertThat(server.handler.closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(server.handler.closeStatus.getReason(), containsString("manually sent server error"));
Frame frame = receiveFrame(client.getInputStream());
assertThat(CloseStatus.getCloseStatus(frame).getCode(), is(CloseStatus.SERVER_ERROR));
}
static class DemandingTestFrameHandler implements SynchronousFrameHandler
{
private CoreSession coreSession;

View File

@ -75,7 +75,8 @@ public class WebSocketOpenTest extends WebSocketTester
setup((s, c) ->
{
assertThat(s.toString(), containsString("CONNECTED"));
WebSocketOpenTest.TestFrameHandler.sendText(s, "Hello", c);
WebSocketOpenTest.TestFrameHandler.sendText(s, "Hello", Callback.NOOP);
c.succeeded();
s.demand(1);
return null;
});
@ -122,7 +123,6 @@ public class WebSocketOpenTest extends WebSocketTester
{
assertThat(s.toString(), containsString("CONNECTED"));
s.close(CloseStatus.SHUTDOWN, "Test close in onOpen", c);
s.demand(1);
return null;
});
@ -132,7 +132,7 @@ public class WebSocketOpenTest extends WebSocketTester
client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
assertTrue(serverHandler.onClosed.await(5, TimeUnit.SECONDS));
assertThat(serverHandler.closeStatus.getCode(), is(CloseStatus.NORMAL));
assertThat(serverHandler.closeStatus.getCode(), is(CloseStatus.SHUTDOWN));
}
@Test

View File

@ -25,7 +25,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -252,17 +251,13 @@ public class WebSocketProxyTest
CloseStatus closeStatus = CloseStatus.getCloseStatus(proxyClientSide.receivedFrames.poll());
assertThat(closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(closeStatus.getReason(), containsString("simulated client onOpen error"));
assertThat(proxyClientSide.getState(), is(WebSocketProxy.State.CLOSED));
closeStatus = CloseStatus.getCloseStatus(proxyServerSide.receivedFrames.poll());
assertThat(closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(closeStatus.getReason(), containsString("simulated client onOpen error"));
assertThat(proxyServerSide.getState(), is(WebSocketProxy.State.CLOSED));
assertThat(proxyClientSide.getState(), is(WebSocketProxy.State.FAILED));
closeStatus = CloseStatus.getCloseStatus(serverFrameHandler.receivedFrames.poll());
assertThat(closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(closeStatus.getReason(), containsString("simulated client onOpen error"));
assertNull(proxyServerSide.receivedFrames.poll());
assertNull(clientFrameHandler.receivedFrames.poll());
}
@ -311,17 +306,14 @@ public class WebSocketProxyTest
assertThat(closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(closeStatus.getReason(), is("intentionally throwing in server onFrame()"));
// Client2Proxy receiving close response from Client
frame = proxyClientSide.receivedFrames.poll();
closeStatus = CloseStatus.getCloseStatus(frame);
assertThat(closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(closeStatus.getReason(), is("intentionally throwing in server onFrame()"));
// Client2Proxy receives no close response because is error close
assertNull(proxyClientSide.receivedFrames.poll());
// Check Proxy is in expected final state
assertNull(proxyClientSide.receivedFrames.poll());
assertNull(proxyServerSide.receivedFrames.poll());
assertThat(proxyClientSide.getState(), is(WebSocketProxy.State.CLOSED));
assertThat(proxyServerSide.getState(), is(WebSocketProxy.State.CLOSED));
assertThat(proxyClientSide.getState(), is(WebSocketProxy.State.FAILED));
assertThat(proxyServerSide.getState(), is(WebSocketProxy.State.FAILED));
}
@Test