Issue #3968 - prevent ReadPending and ISE from AbstractWebSocketConnection (#3979)

* Issue #3968 - websocket suspend fix and cleanups

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>

* Issue #3968 - fixed race conditions when using websocket ReadState

combine the previous ReadMode into ReadState by using ReadState.Action
which is returned from ReadState.getAction(ByteBuffer) where an atomic
decision is made of what action to do

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan 2019-08-14 21:28:35 +10:00 committed by GitHub
parent 8761b345b5
commit 2a109dccbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 185 additions and 214 deletions

View File

@ -121,13 +121,6 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
} }
} }
private enum ReadMode
{
PARSE,
DISCARD,
EOF
}
private static final Logger LOG = Log.getLogger(AbstractWebSocketConnection.class); private static final Logger LOG = Log.getLogger(AbstractWebSocketConnection.class);
private static final AtomicLong ID_GEN = new AtomicLong(0); private static final AtomicLong ID_GEN = new AtomicLong(0);
@ -148,7 +141,6 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
private WebSocketSession session; private WebSocketSession session;
private List<ExtensionConfig> extensions = new ArrayList<>(); private List<ExtensionConfig> extensions = new ArrayList<>();
private ByteBuffer prefillBuffer; private ByteBuffer prefillBuffer;
private ReadMode readMode = ReadMode.PARSE;
private Stats stats = new Stats(); private Stats stats = new Stats();
private CloseInfo fatalCloseInfo; private CloseInfo fatalCloseInfo;
@ -420,10 +412,11 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
public void onFillable() public void onFillable()
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
{
LOG.debug("{} onFillable()", policy.getBehavior()); LOG.debug("{} onFillable()", policy.getBehavior());
}
stats.countOnFillableEvents.incrementAndGet(); stats.countOnFillableEvents.incrementAndGet();
if (readState.getBuffer() != null)
throw new IllegalStateException();
ByteBuffer buffer = bufferPool.acquire(getInputBufferSize(), true); ByteBuffer buffer = bufferPool.acquire(getInputBufferSize(), true);
onFillable(buffer); onFillable(buffer);
} }
@ -431,39 +424,93 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
private void onFillable(ByteBuffer buffer) private void onFillable(ByteBuffer buffer)
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
{
LOG.debug("{} onFillable(ByteBuffer): {}", policy.getBehavior(), buffer); LOG.debug("{} onFillable(ByteBuffer): {}", policy.getBehavior(), buffer);
}
try while (true)
{ {
if (readMode == ReadMode.PARSE) ReadState.Action action = readState.getAction(buffer);
readMode = readParse(buffer); if (LOG.isDebugEnabled())
else LOG.debug("ReadState Action: {}", action);
readMode = readDiscard(buffer);
}
catch (Throwable t)
{
bufferPool.release(buffer);
throw t;
}
if (readMode == ReadMode.EOF) switch (action)
{ {
bufferPool.release(buffer); case PARSE:
readState.eof(); try
{
parser.parseSingleFrame(buffer);
}
catch (Throwable t)
{
close(t);
readState.discard();
}
break;
// Handle case where the remote connection was abruptly terminated without a close frame case FILL:
CloseInfo close = new CloseInfo(StatusCode.SHUTDOWN); try
close(close, new DisconnectCallback(this)); {
} int filled = getEndPoint().fill(buffer);
else if (!readState.suspend()) if (filled < 0)
{ {
bufferPool.release(buffer); readState.eof();
fillInterested(); break;
}
if (filled == 0)
{
// Done reading, wait for next onFillable
bufferPool.release(buffer);
fillInterested();
return;
}
if (LOG.isDebugEnabled())
LOG.debug("Filled {} bytes - {}", filled, BufferUtil.toDetailString(buffer));
}
catch (IOException e)
{
close(e);
readState.eof();
}
break;
case DISCARD:
if (LOG.isDebugEnabled())
LOG.debug("Discarded buffer - {}", BufferUtil.toDetailString(buffer));
buffer.clear();
break;
case SUSPEND:
return;
case EOF:
bufferPool.release(buffer);
// Handle case where the remote connection was abruptly terminated without a close frame
CloseInfo close = new CloseInfo(StatusCode.SHUTDOWN);
close(close, new DisconnectCallback(this));
return;
default:
throw new IllegalStateException(action.name());
}
} }
} }
@Override
public void resume()
{
ByteBuffer resume = readState.resume();
if (resume != null)
onFillable(resume);
}
@Override
public SuspendToken suspend()
{
readState.suspending();
return this;
}
@Override @Override
protected void onFillInterestedFailed(Throwable cause) protected void onFillInterestedFailed(Throwable cause)
{ {
@ -517,120 +564,6 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
} }
} }
private ReadMode readDiscard(ByteBuffer buffer)
{
EndPoint endPoint = getEndPoint();
try
{
while (true)
{
int filled = endPoint.fill(buffer);
if (filled == 0)
{
return ReadMode.DISCARD;
}
else if (filled < 0)
{
if (LOG.isDebugEnabled())
{
LOG.debug("read - EOF Reached (remote: {})", getRemoteAddress());
}
return ReadMode.EOF;
}
else
{
if (LOG.isDebugEnabled())
{
LOG.debug("Discarded {} bytes - {}", filled, BufferUtil.toDetailString(buffer));
}
}
}
}
catch (IOException e)
{
LOG.ignore(e);
return ReadMode.EOF;
}
catch (Throwable t)
{
LOG.ignore(t);
return ReadMode.DISCARD;
}
}
private ReadMode readParse(ByteBuffer buffer)
{
EndPoint endPoint = getEndPoint();
try
{
// Process the content from the Endpoint next
while (true)
{
// We may start with a non empty buffer, consume before filling
while (buffer.hasRemaining())
{
if (readState.suspendParse(buffer))
{
if (LOG.isDebugEnabled())
{
LOG.debug("suspending parse {}", buffer);
}
return ReadMode.PARSE;
}
else
parser.parseSingleFrame(buffer);
}
int filled = endPoint.fill(buffer);
if (filled < 0)
{
if (LOG.isDebugEnabled())
{
LOG.debug("read - EOF Reached (remote: {})", getRemoteAddress());
}
return ReadMode.EOF;
}
else if (filled == 0)
{
// Done reading, wait for next onFillable
return ReadMode.PARSE;
}
if (LOG.isDebugEnabled())
{
LOG.debug("Filled {} bytes - {}", filled, BufferUtil.toDetailString(buffer));
}
}
}
catch (Throwable t)
{
close(t);
return ReadMode.DISCARD;
}
}
@Override
public void resume()
{
ByteBuffer resume = readState.resume();
if (resume == null)
{
fillInterested();
}
else if (resume != ReadState.NO_ACTION)
{
onFillable(resume);
}
}
@Override
public SuspendToken suspend()
{
readState.suspending();
return this;
}
/** /**
* Get the list of extensions in use. * Get the list of extensions in use.
* <p> * <p>

View File

@ -21,14 +21,26 @@ package org.eclipse.jetty.websocket.common.io;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
class ReadState class ReadState
{ {
private static final Logger LOG = Log.getLogger(ReadState.class);
public static final ByteBuffer NO_ACTION = BufferUtil.EMPTY_BUFFER; public static final ByteBuffer NO_ACTION = BufferUtil.EMPTY_BUFFER;
private State state = State.READING; private State state = State.READING;
private ByteBuffer buffer; private ByteBuffer buffer;
public ByteBuffer getBuffer()
{
synchronized (this)
{
return buffer;
}
}
boolean isReading() boolean isReading()
{ {
synchronized (this) synchronized (this)
@ -45,8 +57,38 @@ class ReadState
} }
} }
public Action getAction(ByteBuffer buffer)
{
synchronized (this)
{
if (LOG.isDebugEnabled())
LOG.debug("{} getAction({})", this, BufferUtil.toDetailString(buffer));
switch (state)
{
case READING:
return buffer.hasRemaining() ? Action.PARSE : Action.FILL;
case SUSPENDING:
this.buffer = buffer;
this.state = State.SUSPENDED;
return Action.SUSPEND;
case EOF:
return Action.EOF;
case DISCARDING:
return buffer.hasRemaining() ? Action.DISCARD : Action.FILL;
case SUSPENDED:
default:
throw new IllegalStateException(toString(state));
}
}
}
/** /**
* Requests that reads from the connection be suspended when {@link #suspend()} is called. * Requests that reads from the connection be suspended.
* *
* @return whether the suspending was successful * @return whether the suspending was successful
*/ */
@ -54,6 +96,9 @@ class ReadState
{ {
synchronized (this) synchronized (this)
{ {
if (LOG.isDebugEnabled())
LOG.debug("suspending {}", state);
switch (state) switch (state)
{ {
case READING: case READING:
@ -67,52 +112,6 @@ class ReadState
} }
} }
public boolean suspendParse(ByteBuffer buffer)
{
synchronized (this)
{
switch (state)
{
case READING:
return false;
case SUSPENDING:
this.buffer = buffer;
this.state = State.SUSPENDED;
return true;
default:
throw new IllegalStateException(toString(state));
}
}
}
/**
* Suspends reads from the connection if {@link #suspending()} was called.
*
* @return whether reads from the connection should be suspended
*/
boolean suspend()
{
synchronized (this)
{
switch (state)
{
case READING:
return false;
case SUSPENDING:
state = State.SUSPENDED;
return true;
case SUSPENDED:
if (buffer == null)
throw new IllegalStateException();
return true;
case EOF:
return true;
default:
throw new IllegalStateException(toString(state));
}
}
}
/** /**
* @return a ByteBuffer to finish processing, or null if we should register fillInterested * @return a ByteBuffer to finish processing, or null if we should register fillInterested
* If return value is {@link BufferUtil#EMPTY_BUFFER} no action should be taken. * If return value is {@link BufferUtil#EMPTY_BUFFER} no action should be taken.
@ -121,18 +120,21 @@ class ReadState
{ {
synchronized (this) synchronized (this)
{ {
if (LOG.isDebugEnabled())
LOG.debug("resuming {}", state);
switch (state) switch (state)
{ {
case SUSPENDING: case SUSPENDING:
state = State.READING; state = State.READING;
return NO_ACTION; return null;
case SUSPENDED: case SUSPENDED:
state = State.READING; state = State.READING;
ByteBuffer bb = buffer; ByteBuffer bb = buffer;
buffer = null; buffer = null;
return bb; return bb;
case EOF: case EOF:
return NO_ACTION; return null;
default: default:
throw new IllegalStateException(toString(state)); throw new IllegalStateException(toString(state));
} }
@ -143,10 +145,36 @@ class ReadState
{ {
synchronized (this) synchronized (this)
{ {
if (LOG.isDebugEnabled())
LOG.debug("eof {}", state);
state = State.EOF; state = State.EOF;
} }
} }
public void discard()
{
synchronized (this)
{
if (LOG.isDebugEnabled())
LOG.debug("discard {}", state);
switch (state)
{
case READING:
case SUSPENDED:
case SUSPENDING:
state = State.DISCARDING;
break;
case DISCARDING:
case EOF:
default:
throw new IllegalStateException(toString(state));
}
}
}
private String toString(State state) private String toString(State state)
{ {
return String.format("%s@%x[%s]", getClass().getSimpleName(), hashCode(), state); return String.format("%s@%x[%s]", getClass().getSimpleName(), hashCode(), state);
@ -161,6 +189,15 @@ class ReadState
} }
} }
public enum Action
{
FILL,
PARSE,
DISCARD,
SUSPEND,
EOF
}
private enum State private enum State
{ {
/** /**
@ -178,6 +215,11 @@ class ReadState
*/ */
SUSPENDED, SUSPENDED,
/**
* Reading from connection and discarding bytes until EOF.
*/
DISCARDING,
/** /**
* Won't read from the connection (terminal state). * Won't read from the connection (terminal state).
*/ */

View File

@ -18,6 +18,9 @@
package org.eclipse.jetty.websocket.common.io; package org.eclipse.jetty.websocket.common.io;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
@ -34,7 +37,7 @@ public class ReadStateTest
ReadState readState = new ReadState(); ReadState readState = new ReadState();
assertThat("Initially reading", readState.isReading(), is(true)); assertThat("Initially reading", readState.isReading(), is(true));
assertThat("No prior suspending", readState.suspend(), is(false)); assertThat("Action is reading", readState.getAction(BufferUtil.toBuffer("content")), is(ReadState.Action.PARSE));
assertThat("No prior suspending", readState.isSuspended(), is(false)); assertThat("No prior suspending", readState.isSuspended(), is(false));
assertThrows(IllegalStateException.class, readState::resume, "No suspending to resume"); assertThrows(IllegalStateException.class, readState::resume, "No suspending to resume");
@ -50,10 +53,8 @@ public class ReadStateTest
assertTrue(readState.suspending()); assertTrue(readState.suspending());
assertThat("Suspending doesn't take effect immediately", readState.isSuspended(), is(false)); assertThat("Suspending doesn't take effect immediately", readState.isSuspended(), is(false));
assertThat("Resume from suspending requires no followup", readState.resume(), is(ReadState.NO_ACTION)); assertNull(readState.resume());
assertThat("Resume from suspending requires no followup", readState.isSuspended(), is(false)); assertThat("Action is reading", readState.getAction(BufferUtil.toBuffer("content")), is(ReadState.Action.PARSE));
assertThat("Suspending was discarded", readState.suspend(), is(false));
assertThat("Suspending was discarded", readState.isSuspended(), is(false)); assertThat("Suspending was discarded", readState.isSuspended(), is(false));
} }
@ -66,10 +67,11 @@ public class ReadStateTest
assertThat(readState.suspending(), is(true)); assertThat(readState.suspending(), is(true));
assertThat("Suspending doesn't take effect immediately", readState.isSuspended(), is(false)); assertThat("Suspending doesn't take effect immediately", readState.isSuspended(), is(false));
assertThat("Suspended", readState.suspend(), is(true)); ByteBuffer content = BufferUtil.toBuffer("content");
assertThat(readState.getAction(content), is(ReadState.Action.SUSPEND));
assertThat("Suspended", readState.isSuspended(), is(true)); assertThat("Suspended", readState.isSuspended(), is(true));
assertNull(readState.resume(), "Resumed"); assertThat(readState.resume(), is(content));
assertThat("Resumed", readState.isSuspended(), is(false)); assertThat("Resumed", readState.isSuspended(), is(false));
} }
@ -77,19 +79,13 @@ public class ReadStateTest
public void testEof() public void testEof()
{ {
ReadState readState = new ReadState(); ReadState readState = new ReadState();
ByteBuffer content = BufferUtil.toBuffer("content");
readState.eof(); readState.eof();
assertThat(readState.isReading(), is(false)); assertThat(readState.isReading(), is(false));
assertThat(readState.isSuspended(), is(true)); assertThat(readState.isSuspended(), is(true));
assertThat(readState.suspend(), is(true));
assertThat(readState.suspending(), is(false)); assertThat(readState.suspending(), is(false));
assertThat(readState.isSuspended(), is(true)); assertThat(readState.getAction(content), is(ReadState.Action.EOF));
assertNull(readState.resume());
assertThat(readState.suspend(), is(true));
assertThat(readState.isSuspended(), is(true));
assertThat(readState.resume(), is(ReadState.NO_ACTION));
assertThat(readState.isSuspended(), is(true));
} }
} }