Issue #4538 - Rework MessageInputStream and MessageReader

Message reader now validates UTF8

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-02-17 16:13:00 +11:00
parent 6eccc7ebce
commit 2467d5a8c5
7 changed files with 256 additions and 68 deletions

View File

@ -26,7 +26,9 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
@ -36,6 +38,7 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeout;
public class MessageInputStreamTest
@ -166,7 +169,7 @@ public class MessageInputStreamTest
{
// wait for a little bit before sending input closed
TimeUnit.MILLISECONDS.sleep(400);
stream.close();
stream.accept(new Frame(OpCode.TEXT, true, BufferUtil.EMPTY_BUFFER), Callback.NOOP);
}
catch (Throwable t)
{
@ -177,11 +180,22 @@ public class MessageInputStreamTest
// Read byte from stream.
int b = stream.read();
// Should be a -1, indicating the end of the stream.
// Test it
// Should be a -1, indicating the end of the stream.
assertThat("Error when closing", hadError.get(), is(false));
assertThat("Initial byte (Should be EOF)", b, is(-1));
// Close the stream.
stream.close();
// Any frame content after stream is closed should be discarded, and the callback succeeded.
FutureCallback callback = new FutureCallback();
stream.accept(new Frame(OpCode.TEXT, true, BufferUtil.toBuffer("hello world")), callback);
callback.block(5, TimeUnit.SECONDS);
// Any read after the stream is closed leads to an IOException.
IOException error = assertThrows(IOException.class, stream::read);
assertThat(error.getMessage(), is("Closed"));
}
});
}

View File

@ -22,7 +22,6 @@ import java.lang.invoke.MethodHandle;
import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
@ -114,7 +113,7 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
typeSink = newSink(frame);
dispatchComplete = new CompletableFuture<>();
// Dispatch to end user function (will likely start with blocking for data/accept)
// Dispatch to end user function (will likely start with blocking for data/accept).
new Thread(() ->
{
try

View File

@ -22,9 +22,9 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
@ -43,8 +43,9 @@ public class MessageInputStream extends InputStream implements MessageSink
{
private static final Logger LOG = Log.getLogger(MessageInputStream.class);
private static final Entry EOF = new Entry(BufferUtil.EMPTY_BUFFER, Callback.NOOP);
private static final Entry CLOSED = new Entry(BufferUtil.EMPTY_BUFFER, Callback.NOOP);
private final BlockingArrayQueue<Entry> buffers = new BlockingArrayQueue<>();
private final AtomicBoolean closed = new AtomicBoolean(false);
private boolean closed = false;
private Entry currentEntry;
private long timeoutMs = -1;
@ -54,20 +55,28 @@ public class MessageInputStream extends InputStream implements MessageSink
if (LOG.isDebugEnabled())
LOG.debug("accepting {}", frame);
// If closed or we have no payload, request the next frame.
if (closed.get() || (!frame.hasPayload() && !frame.isFin()))
boolean succeed = false;
synchronized (this)
{
callback.succeeded();
return;
// If closed or we have no payload, request the next frame.
if (closed || (!frame.hasPayload() && !frame.isFin()))
{
succeed = true;
}
else
{
if (frame.hasPayload())
buffers.add(new Entry(frame.getPayload(), callback));
else
succeed = true;
if (frame.isFin())
buffers.add(EOF);
}
}
if (frame.hasPayload())
buffers.add(new Entry(frame.getPayload(), callback));
else
if (succeed)
callback.succeeded();
if (frame.isFin())
buffers.add(EOF);
}
@Override
@ -88,31 +97,31 @@ public class MessageInputStream extends InputStream implements MessageSink
@Override
public int read(final byte[] b, final int off, final int len) throws IOException
{
if (closed.get())
return -1;
return read(ByteBuffer.wrap(b, off, len).flip());
}
public int read(ByteBuffer buffer) throws IOException
{
Entry result = getCurrentEntry();
if (LOG.isDebugEnabled())
LOG.debug("result = {}", result);
if (result == CLOSED)
throw new IOException("Closed");
if (result == EOF)
{
if (LOG.isDebugEnabled())
LOG.debug("Read EOF");
shutdown();
return -1;
}
// We have content
int fillLen = Math.min(result.buffer.remaining(), len);
result.buffer.get(b, off, fillLen);
// We have content.
int fillLen = BufferUtil.append(buffer, result.buffer);
if (!result.buffer.hasRemaining())
{
currentEntry = null;
result.callback.succeeded();
}
succeedCurrentEntry();
// return number of bytes actually copied into buffer
// Return number of bytes actually copied into buffer.
return fillLen;
}
@ -122,72 +131,89 @@ public class MessageInputStream extends InputStream implements MessageSink
if (LOG.isDebugEnabled())
LOG.debug("close()");
if (closed.compareAndSet(false, true))
ArrayList<Entry> failedEntries = new ArrayList<>();
synchronized (this)
{
synchronized (buffers)
if (closed)
return;
closed = true;
if (currentEntry != null)
{
buffers.offer(EOF);
buffers.notify();
failedEntries.add(currentEntry);
currentEntry = null;
}
// Clear queue and fail all entries.
failedEntries.addAll(buffers);
buffers.clear();
buffers.offer(CLOSED);
}
Throwable cause = new IOException("Closed");
for (Entry e : failedEntries)
{
e.callback.failed(cause);
}
super.close();
}
private void shutdown()
{
if (LOG.isDebugEnabled())
LOG.debug("shutdown()");
synchronized (this)
{
closed.set(true);
Throwable cause = new IOException("Shutdown");
for (Entry buffer : buffers)
{
buffer.callback.failed(cause);
}
// Removed buffers that may have remained in the queue.
buffers.clear();
}
}
public void setTimeout(long timeoutMs)
{
this.timeoutMs = timeoutMs;
}
private void succeedCurrentEntry()
{
Entry current;
synchronized (this)
{
current = currentEntry;
currentEntry = null;
}
if (current != null)
current.callback.succeeded();
}
private Entry getCurrentEntry() throws IOException
{
if (currentEntry != null)
return currentEntry;
synchronized (this)
{
if (currentEntry != null)
return currentEntry;
}
// sync and poll queue
try
{
if (LOG.isDebugEnabled())
LOG.debug("Waiting {} ms to read", timeoutMs);
Entry result;
if (timeoutMs < 0)
{
// Wait forever until a buffer is available.
currentEntry = buffers.take();
result = buffers.take();
}
else
{
// Wait at most for the given timeout.
currentEntry = buffers.poll(timeoutMs, TimeUnit.MILLISECONDS);
if (currentEntry == null)
result = buffers.poll(timeoutMs, TimeUnit.MILLISECONDS);
if (result == null)
throw new IOException(String.format("Read timeout: %,dms expired", timeoutMs));
}
synchronized (this)
{
currentEntry = result;
return currentEntry;
}
}
catch (InterruptedException e)
{
shutdown();
close();
throw new InterruptedIOException();
}
return currentEntry;
}
private static class Entry

View File

@ -18,35 +18,76 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.InputStreamReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.WebSocketConstants;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
* Support class for reading a (single) WebSocket TEXT message via a Reader.
* <p>
* In compliance to the WebSocket spec, this reader always uses the {@link StandardCharsets#UTF_8}.
*/
public class MessageReader extends InputStreamReader implements MessageSink
public class MessageReader extends Reader implements MessageSink
{
private static final int BUFFER_SIZE = WebSocketConstants.DEFAULT_INPUT_BUFFER_SIZE;
private final ByteBuffer buffer;
private final MessageInputStream stream;
private final CharsetDecoder utf8Decoder = UTF_8.newDecoder()
.onUnmappableCharacter(CodingErrorAction.REPORT)
.onMalformedInput(CodingErrorAction.REPORT);
public MessageReader()
{
this(new MessageInputStream());
this(BUFFER_SIZE);
}
private MessageReader(MessageInputStream inputStream)
public MessageReader(int bufferSize)
{
super(inputStream, StandardCharsets.UTF_8);
this.stream = inputStream;
this.stream = new MessageInputStream();
this.buffer = BufferUtil.allocate(bufferSize);
}
@Override
public int read(char[] cbuf, int off, int len) throws IOException
{
CharBuffer charBuffer = CharBuffer.wrap(cbuf, off, len);
if (!buffer.hasRemaining())
{
int read = stream.read(buffer);
if (read == 0)
return read;
if (read < 0)
{
utf8Decoder.decode(BufferUtil.EMPTY_BUFFER, charBuffer, true);
return (charBuffer.position() > 0) ? charBuffer.position() : read;
}
}
utf8Decoder.decode(buffer, charBuffer, false);
return charBuffer.position();
}
@Override
public void close() throws IOException
{
stream.close();
}
@Override
public void accept(Frame frame, Callback callback)
{
this.stream.accept(frame, callback);
stream.accept(frame, callback);
}
}

View File

@ -25,6 +25,7 @@ import java.nio.charset.CharsetEncoder;
import java.nio.charset.CodingErrorAction;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.OpCode;
@ -66,4 +67,9 @@ public class MessageWriter extends Writer
{
outputStream.close();
}
public void setCallback(Callback callback)
{
outputStream.setCallback(callback);
}
}

View File

@ -33,6 +33,6 @@ public class ReaderMessageSink extends DispatchedMessageSink
@Override
public MessageReader newSink(Frame frame)
{
return new MessageReader();
return new MessageReader(session.getInputBufferSize());
}
}

View File

@ -0,0 +1,102 @@
package org.eclipse.jetty.websocket.util;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.util.messages.MessageReader;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageReaderTest
{
private final MessageReader reader = new MessageReader();
private final CompletableFuture<String> message = new CompletableFuture<>();
private boolean first = true;
@BeforeEach
public void before()
{
// Read the message in a different thread.
new Thread(() ->
{
try
{
message.complete(IO.toString(reader));
}
catch (IOException e)
{
message.completeExceptionally(e);
}
}).start();
}
@Test
public void testSingleFrameMessage() throws Exception
{
giveString("hello world!", true);
String s = message.get(5, TimeUnit.SECONDS);
assertThat(s, is("hello world!"));
}
@Test
public void testFragmentedMessage() throws Exception
{
giveString("hello", false);
giveString(" ", false);
giveString("world", false);
giveString("!", true);
String s = message.get(5, TimeUnit.SECONDS);
assertThat(s, is("hello world!"));
}
@Test
public void testEmptySegments() throws Exception
{
giveString("", false);
giveString("hello ", false);
giveString("", false);
giveString("", false);
giveString("world!", false);
giveString("", false);
giveString("", true);
String s = message.get(5, TimeUnit.SECONDS);
assertThat(s, is("hello world!"));
}
@Test
public void testCloseStream() throws Exception
{
giveString("hello ", false);
reader.close();
giveString("world!", true);
ExecutionException error = assertThrows(ExecutionException.class, () -> message.get(5, TimeUnit.SECONDS));
Throwable cause = error.getCause();
assertThat(cause, instanceOf(IOException.class));
assertThat(cause.getMessage(), is("Closed"));
}
private void giveString(String s, boolean last) throws IOException
{
byte opCode = first ? OpCode.TEXT : OpCode.CONTINUATION;
Frame frame = new Frame(opCode, last, s);
FutureCallback callback = new FutureCallback();
reader.accept(frame, callback);
callback.block(5, TimeUnit.SECONDS);
first = false;
}
}