Issue #3290 - fixing issues with WebSocketClose

introduce channelState check in the catch in WSChannel sendFrame
to guard from multiple closes

WebSocketConnection fillAndParse will now try to read until EOF

removed state change in the isOutputOpen check in webSocketChannelState
to as we do the state change in the catch block in WSChannel

added and improved WebSocketCloseTest to test more cases

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2019-02-07 09:58:43 +11:00
parent ff1f3ca3be
commit 1709c90286
6 changed files with 112 additions and 68 deletions

View File

@ -172,6 +172,19 @@ public class CloseStatus
return null;
}
public static boolean isOrdinary(CloseStatus closeStatus)
{
switch (closeStatus.getCode())
{
case NORMAL:
case SHUTDOWN:
return true;
default:
return false;
}
}
public int getCode()
{
return code;

View File

@ -527,7 +527,7 @@ public class WebSocketChannel implements IncomingFrames, FrameHandler.CoreSessio
if (frame.getOpCode() == OpCode.CLOSE)
{
CloseStatus closeStatus = CloseStatus.getCloseStatus(frame);
if (closeStatus instanceof AbnormalCloseStatus)
if (closeStatus instanceof AbnormalCloseStatus && channelState.onClosed(closeStatus))
closeConnection(null, closeStatus, Callback.from(
()->callback.failed(ex),
x2->

View File

@ -149,11 +149,7 @@ public class WebSocketChannelState
synchronized (this)
{
if (!isOutputOpen())
{
if (opcode == OpCode.CLOSE && CloseStatus.getCloseStatus(frame) instanceof WebSocketChannel.AbnormalCloseStatus)
_channelState = State.CLOSED;
throw new IllegalStateException(_channelState.toString());
}
if (opcode == OpCode.CLOSE)
{

View File

@ -348,12 +348,10 @@ public class WebSocketConnection extends AbstractConnection implements Connectio
if (!fillingAndParsing)
throw new IllegalStateException();
if (demand > 0)
if (demand != 0)
return true;
if (demand == 0)
fillingAndParsing = false;
fillingAndParsing = false;
if (networkBuffer.isEmpty())
releaseNetworkBuffer();
@ -373,10 +371,9 @@ public class WebSocketConnection extends AbstractConnection implements Connectio
if (!fillingAndParsing)
throw new IllegalStateException();
if (demand < 0)
return false;
if (demand > 0)
demand--;
demand--;
return true;
}
}
@ -412,7 +409,6 @@ public class WebSocketConnection extends AbstractConnection implements Connectio
if (!moreDemand())
return;
}
}
// buffer must be empty here because parser is fully consuming
@ -532,43 +528,6 @@ public class WebSocketConnection extends AbstractConnection implements Connectio
generator);
}
@Override
public int hashCode()
{
final int prime = 31;
int result = 1;
EndPoint endp = getEndPoint();
if (endp != null)
{
result = prime * result + endp.getLocalAddress().hashCode();
result = prime * result + endp.getRemoteAddress().hashCode();
}
return result;
}
@Override
public boolean equals(Object obj)
{
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
WebSocketConnection other = (WebSocketConnection)obj;
EndPoint endp = getEndPoint();
EndPoint otherEndp = other.getEndPoint();
if (endp == null)
{
if (otherEndp != null)
return false;
}
else if (!endp.equals(otherEndp))
return false;
return true;
}
/**
* Extra bytes from the initial HTTP upgrade that need to
* be processed by the websocket parser before starting

View File

@ -118,7 +118,6 @@ public class WebSocketCloseTest extends WebSocketTester
server.handler.getCoreSession().demand(1);
client.getOutputStream().write(RawFrameBuilder.buildClose(new CloseStatus(CloseStatus.NORMAL), true));
Frame frame = serverHandler.receivedFrames.poll(10, TimeUnit.SECONDS);
assertNotNull(frame);
assertThat(new CloseStatus(frame.getPayload()).getCode(), is(CloseStatus.NORMAL));
assertThat(server.handler.getCoreSession().toString(), containsString("ISHUT"));
@ -143,9 +142,8 @@ public class WebSocketCloseTest extends WebSocketTester
}
server.sendFrame(CloseStatus.toFrame(CloseStatus.NORMAL));
Frame frame = receiveFrame(client.getInputStream());
assertNotNull(frame);
assertThat(new CloseStatus(frame.getPayload()).getCode(), is(CloseStatus.NORMAL));
CloseStatus closeStatus = new CloseStatus(receiveFrame(client.getInputStream()));
assertThat(closeStatus.getCode(), is(CloseStatus.NORMAL));
assertThat(server.handler.getCoreSession().toString(), containsString("OSHUT"));
LOG.info("Server: OSHUT");
@ -162,7 +160,6 @@ public class WebSocketCloseTest extends WebSocketTester
server.handler.receivedCallback.poll().succeeded();
Frame frame = receiveFrame(client.getInputStream());
assertNotNull(frame);
assertThat(new CloseStatus(frame.getPayload()).getCode(), is(CloseStatus.NORMAL));
assertTrue(server.handler.closed.await(10, TimeUnit.SECONDS));
@ -177,7 +174,6 @@ public class WebSocketCloseTest extends WebSocketTester
server.sendFrame(CloseStatus.toFrame(CloseStatus.SHUTDOWN));
server.handler.receivedCallback.poll().succeeded();
Frame frame = receiveFrame(client.getInputStream());
assertNotNull(frame);
assertThat(new CloseStatus(frame.getPayload()).getCode(), is(CloseStatus.SHUTDOWN));
assertTrue(server.handler.closed.await(10, TimeUnit.SECONDS));
@ -190,14 +186,27 @@ public class WebSocketCloseTest extends WebSocketTester
setup(State.ISHUT);
server.handler.receivedCallback.poll().failed(new Exception("test failure"));
Frame frame = receiveFrame(client.getInputStream());
assertNotNull(frame);
assertThat(new CloseStatus(frame.getPayload()).getCode(), is(CloseStatus.SERVER_ERROR));
CloseStatus closeStatus = new CloseStatus(receiveFrame(client.getInputStream()));
assertThat(closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
assertThat(closeStatus.getReason(), is("test failure"));
assertTrue(server.handler.closed.await(10, TimeUnit.SECONDS));
assertThat(server.handler.closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
}
@Test
public void clientClosesOutput_ISHUT() throws Exception
{
setup(State.ISHUT);
client.shutdownOutput();
assertFalse(server.handler.closed.await(250, TimeUnit.MILLISECONDS));
server.handler.receivedCallback.poll().succeeded();
CloseStatus closeStatus = new CloseStatus(receiveFrame(client.getInputStream()));
assertThat(closeStatus.getCode(), is(CloseStatus.NORMAL));
}
@Test
public void clientClose_OSHUT() throws Exception
{
@ -276,11 +285,57 @@ public class WebSocketCloseTest extends WebSocketTester
setup(State.ISHUT);
client.getOutputStream().write(RawFrameBuilder.buildFrame(OpCode.PONG, "pong frame not masked", false));
assertFalse(server.handler.closed.await(250, TimeUnit.MILLISECONDS));
assertTrue(server.handler.closed.await(5, TimeUnit.SECONDS));
assertThat(server.handler.closeStatus.getCode(), is(CloseStatus.PROTOCOL));
server.close();
Frame frame = receiveFrame(client.getInputStream());
assertThat(CloseStatus.getCloseStatus(frame).getCode(), is(CloseStatus.PROTOCOL));
receiveEof(client.getInputStream());
}
@Test
public void clientHalfClose_ISHUT() throws Exception
{
setup(State.ISHUT);
client.shutdownOutput();
assertFalse(server.handler.closed.await(250, TimeUnit.MILLISECONDS));
Callback callback = server.handler.receivedCallback.poll(5, TimeUnit.SECONDS);
callback.succeeded();
assertTrue(server.handler.closed.await(5, TimeUnit.SECONDS));
assertThat(server.handler.closeStatus.getCode(), is(CloseStatus.NORMAL));
Frame frame = receiveFrame(client.getInputStream());
assertThat(CloseStatus.getCloseStatus(frame).getCode(), is(CloseStatus.NORMAL));
receiveEof(client.getInputStream());
}
@Test
public void clientCloseServerWrite_ISHUT() throws Exception
{
setup(State.ISHUT);
client.close();
assertFalse(server.handler.closed.await(250, TimeUnit.MILLISECONDS));
while(true)
{
if (!server.isOpen())
break;
Callback callback = Callback.from(()->System.err.println("Succeeded Frame After Close"),
(t)->System.err.println("Failed Frame After Close"));
server.sendFrame(new Frame(OpCode.TEXT, BufferUtil.toBuffer("frame after close")), callback);
}
assertTrue(server.handler.closed.await(5, TimeUnit.SECONDS));
assertNotNull(server.handler.error);
assertThat(server.handler.closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
Callback callback = server.handler.receivedCallback.poll(5, TimeUnit.SECONDS);
callback.succeeded();
assertThat(server.handler.closeStatus.getCode(), is(CloseStatus.SERVER_ERROR));
}
@Test
@ -360,6 +415,7 @@ public class WebSocketCloseTest extends WebSocketTester
protected BlockingQueue<Frame> receivedFrames = new BlockingArrayQueue<>();
protected BlockingQueue<Callback> receivedCallback = new BlockingArrayQueue<>();
protected volatile Throwable error = null;
protected CountDownLatch opened = new CountDownLatch(1);
protected CountDownLatch closed = new CountDownLatch(1);
protected CloseStatus closeStatus = null;
@ -408,6 +464,7 @@ public class WebSocketCloseTest extends WebSocketTester
public void onError(Throwable cause)
{
LOG.info("onError {} ", cause == null?null:cause.toString());
error = cause;
state = session.toString();
}
@ -475,6 +532,11 @@ public class WebSocketCloseTest extends WebSocketTester
handler.getCoreSession().sendFrame(frame, NOOP, false);
}
public void sendFrame(Frame frame, Callback callback)
{
handler.getCoreSession().sendFrame(frame, callback, false);
}
public void sendText(String line)
{
LOG.info("sending {}...", line);

View File

@ -18,6 +18,13 @@
package org.eclipse.jetty.websocket.core;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.io.ArrayByteBufferPool;
@ -27,13 +34,6 @@ import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.core.internal.Parser;
import org.junit.jupiter.api.BeforeEach;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.startsWith;
@ -124,4 +124,18 @@ public class WebSocketTester
return frame;
}
}
protected void receiveEof(InputStream in) throws IOException
{
ByteBuffer buffer = bufferPool.acquire(4096, false);
while (true)
{
BufferUtil.flipToFill(buffer);
int len = in.read(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining());
if (len < 0)
return;
throw new IllegalStateException("unexpected content");
}
}
}