Issue #4104 - WebSocketSession will reject outgoing frames if closed

Outgoing frames will now go RemoteEndpoint->Session->ExtensionStack
instead of just RemoteEndpoint->ExtensionStack.

This will allow the Session to check whether it has been closed before
allowing the frame through the ExtensionStack.

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2019-09-20 10:30:31 +10:00
parent 42f1214796
commit 5a52235464
7 changed files with 46 additions and 59 deletions

View File

@ -37,7 +37,7 @@ public class EventSocket
{ {
private static Logger LOG = Log.getLogger(EventSocket.class); private static Logger LOG = Log.getLogger(EventSocket.class);
protected Session session; public Session session;
private String behavior; private String behavior;
public volatile Throwable failure = null; public volatile Throwable failure = null;
public volatile int closeCode = -1; public volatile int closeCode = -1;

View File

@ -19,20 +19,17 @@
package org.eclipse.jetty.websocket.tests; package org.eclipse.jetty.websocket.tests;
import java.net.URI; import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.log.StacklessLogging;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode; import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.WebSocketException;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.common.WebSocketSession;
import org.eclipse.jetty.websocket.common.extensions.compress.CompressExtension;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet; import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
@ -105,12 +102,8 @@ public class WriteAfterStopTest
assertThat(clientSocket.closeCode, is(StatusCode.NORMAL)); assertThat(clientSocket.closeCode, is(StatusCode.NORMAL));
assertThat(serverSocket.closeCode, is(StatusCode.NORMAL)); assertThat(serverSocket.closeCode, is(StatusCode.NORMAL));
((WebSocketSession)session).stop(); WebSocketException failure = assertThrows(WebSocketException.class, () ->
clientSocket.session.getRemote().sendString("this should fail before ExtensionStack"));
try (StacklessLogging stacklessLogging = new StacklessLogging(CompressExtension.class)) assertThat(failure.getMessage(), is("Session closed"));
{
assertThrows(ClosedChannelException.class,
() -> session.getRemote().sendString("hello world"));
}
} }
} }

View File

@ -601,7 +601,6 @@ public class WebSocketUpgradeRequest extends HttpRequest implements CompleteList
session.setUpgradeResponse(new ClientUpgradeResponse(response)); session.setUpgradeResponse(new ClientUpgradeResponse(response));
connection.addListener(session); connection.addListener(session);
ExtensionStack extensionStack = new ExtensionStack(getExtensionFactory());
List<ExtensionConfig> extensions = new ArrayList<>(); List<ExtensionConfig> extensions = new ArrayList<>();
HttpField extField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_EXTENSIONS); HttpField extField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
if (extField != null) if (extField != null)
@ -619,8 +618,9 @@ public class WebSocketUpgradeRequest extends HttpRequest implements CompleteList
} }
} }
} }
extensionStack.negotiate(extensions);
ExtensionStack extensionStack = new ExtensionStack(getExtensionFactory());
extensionStack.negotiate(extensions);
extensionStack.configure(connection.getParser()); extensionStack.configure(connection.getParser());
extensionStack.configure(connection.getGenerator()); extensionStack.configure(connection.getGenerator());

View File

@ -18,7 +18,6 @@
package org.eclipse.jetty.websocket.client.io; package org.eclipse.jetty.websocket.client.io;
import java.net.InetSocketAddress;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.io.ByteBufferPool;
@ -28,7 +27,6 @@ import org.eclipse.jetty.websocket.api.BatchMode;
import org.eclipse.jetty.websocket.api.WebSocketPolicy; import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.api.extensions.IncomingFrames;
import org.eclipse.jetty.websocket.client.masks.Masker; import org.eclipse.jetty.websocket.client.masks.Masker;
import org.eclipse.jetty.websocket.client.masks.RandomMasker; import org.eclipse.jetty.websocket.client.masks.RandomMasker;
import org.eclipse.jetty.websocket.common.WebSocketFrame; import org.eclipse.jetty.websocket.common.WebSocketFrame;
@ -47,18 +45,6 @@ public class WebSocketClientConnection extends AbstractWebSocketConnection
this.masker = new RandomMasker(); this.masker = new RandomMasker();
} }
@Override
public InetSocketAddress getLocalAddress()
{
return getEndPoint().getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress()
{
return getEndPoint().getRemoteAddress();
}
/** /**
* Override to set the masker. * Override to set the masker.
*/ */
@ -71,10 +57,4 @@ public class WebSocketClientConnection extends AbstractWebSocketConnection
} }
super.outgoingFrame(frame, callback, batchMode); super.outgoingFrame(frame, callback, batchMode);
} }
@Override
public void setNextIncomingFrames(IncomingFrames incoming)
{
getParser().setIncomingFramesHandler(incoming);
}
} }

View File

@ -48,7 +48,9 @@ import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.SuspendToken; import org.eclipse.jetty.websocket.api.SuspendToken;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.api.WebSocketException;
import org.eclipse.jetty.websocket.api.WebSocketPolicy; import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionFactory; import org.eclipse.jetty.websocket.api.extensions.ExtensionFactory;
import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.api.extensions.IncomingFrames; import org.eclipse.jetty.websocket.api.extensions.IncomingFrames;
@ -59,7 +61,7 @@ import org.eclipse.jetty.websocket.common.scopes.WebSocketContainerScope;
import org.eclipse.jetty.websocket.common.scopes.WebSocketSessionScope; import org.eclipse.jetty.websocket.common.scopes.WebSocketSessionScope;
@ManagedObject("A Jetty WebSocket Session") @ManagedObject("A Jetty WebSocket Session")
public class WebSocketSession extends ContainerLifeCycle implements Session, RemoteEndpointFactory, WebSocketSessionScope, IncomingFrames, Connection.Listener public class WebSocketSession extends ContainerLifeCycle implements Session, RemoteEndpointFactory, WebSocketSessionScope, IncomingFrames, OutgoingFrames, Connection.Listener
{ {
private static final Logger LOG = Log.getLogger(WebSocketSession.class); private static final Logger LOG = Log.getLogger(WebSocketSession.class);
private final WebSocketContainerScope containerScope; private final WebSocketContainerScope containerScope;
@ -334,6 +336,26 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Rem
} }
} }
@Override
public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
{
if (onCloseCalled.get())
{
try
{
if (callback != null)
callback.writeFailed(new WebSocketException("Session closed"));
}
catch (Throwable x)
{
LOG.debug("Exception while notifying failure of callback " + callback, x);
}
return;
}
outgoingHandler.outgoingFrame(frame, callback, batchMode);
}
@Override @Override
public boolean isOpen() public boolean isOpen()
{ {
@ -420,7 +442,7 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Rem
@Override @Override
public WebSocketRemoteEndpoint newRemoteEndpoint(LogicalConnection connection, OutgoingFrames outgoingFrames, BatchMode batchMode) public WebSocketRemoteEndpoint newRemoteEndpoint(LogicalConnection connection, OutgoingFrames outgoingFrames, BatchMode batchMode)
{ {
return new WebSocketRemoteEndpoint(connection, outgoingHandler, getBatchMode()); return new WebSocketRemoteEndpoint(connection, outgoingFrames, getBatchMode());
} }
/** /**
@ -443,7 +465,7 @@ public class WebSocketSession extends ContainerLifeCycle implements Session, Rem
if (connection.opening()) if (connection.opening())
{ {
// Connect remote // Connect remote
remote = remoteEndpointFactory.newRemoteEndpoint(connection, outgoingHandler, getBatchMode()); remote = remoteEndpointFactory.newRemoteEndpoint(connection, this, getBatchMode());
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("[{}] {}.open() remote={}", policy.getBehavior(), this.getClass().getSimpleName(), remote); LOG.debug("[{}] {}.open() remote={}", policy.getBehavior(), this.getClass().getSimpleName(), remote);

View File

@ -49,6 +49,7 @@ import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.api.extensions.IncomingFrames;
import org.eclipse.jetty.websocket.common.CloseInfo; import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.Generator; import org.eclipse.jetty.websocket.common.Generator;
import org.eclipse.jetty.websocket.common.LogicalConnection; import org.eclipse.jetty.websocket.common.LogicalConnection;
@ -392,6 +393,12 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
return this.policy; return this.policy;
} }
@Override
public InetSocketAddress getLocalAddress()
{
return getEndPoint().getLocalAddress();
}
@Override @Override
public InetSocketAddress getRemoteAddress() public InetSocketAddress getRemoteAddress()
{ {
@ -649,6 +656,12 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
setInitialBuffer(prefilled); setInitialBuffer(prefilled);
} }
@Override
public void setNextIncomingFrames(IncomingFrames incoming)
{
getParser().setIncomingFramesHandler(incoming);
}
/** /**
* @return the number of WebSocket frames received over this connection * @return the number of WebSocket frames received over this connection
*/ */

View File

@ -18,7 +18,6 @@
package org.eclipse.jetty.websocket.server; package org.eclipse.jetty.websocket.server;
import java.net.InetSocketAddress;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.io.ByteBufferPool;
@ -26,7 +25,6 @@ import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint; import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.thread.Scheduler; import org.eclipse.jetty.util.thread.Scheduler;
import org.eclipse.jetty.websocket.api.WebSocketPolicy; import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.extensions.IncomingFrames;
import org.eclipse.jetty.websocket.common.io.AbstractWebSocketConnection; import org.eclipse.jetty.websocket.common.io.AbstractWebSocketConnection;
public class WebSocketServerConnection extends AbstractWebSocketConnection implements Connection.UpgradeTo public class WebSocketServerConnection extends AbstractWebSocketConnection implements Connection.UpgradeTo
@ -34,27 +32,8 @@ public class WebSocketServerConnection extends AbstractWebSocketConnection imple
public WebSocketServerConnection(EndPoint endp, Executor executor, Scheduler scheduler, WebSocketPolicy policy, ByteBufferPool bufferPool) public WebSocketServerConnection(EndPoint endp, Executor executor, Scheduler scheduler, WebSocketPolicy policy, ByteBufferPool bufferPool)
{ {
super(endp, executor, scheduler, policy, bufferPool); super(endp, executor, scheduler, policy, bufferPool);
if (policy.getIdleTimeout() > 0) if (policy.getIdleTimeout() > 0)
{
endp.setIdleTimeout(policy.getIdleTimeout()); endp.setIdleTimeout(policy.getIdleTimeout());
} }
} }
@Override
public InetSocketAddress getLocalAddress()
{
return getEndPoint().getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress()
{
return getEndPoint().getRemoteAddress();
}
@Override
public void setNextIncomingFrames(IncomingFrames incoming)
{
getParser().setIncomingFramesHandler(incoming);
}
}