Issue #6566 - use counter in BufferCallbackAccumulator, fix InputStreamMessageSinkTest failures

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2021-08-18 17:10:12 +10:00
parent 65ff0bb081
commit 026261f482
6 changed files with 101 additions and 78 deletions

View File

@ -24,6 +24,7 @@ import org.eclipse.jetty.util.Callback;
public class BufferCallbackAccumulator
{
private final List<Entry> _entries = new ArrayList<>();
private int _length;
private static class Entry
{
@ -40,6 +41,7 @@ public class BufferCallbackAccumulator
public void addEntry(ByteBuffer buffer, Callback callback)
{
_entries.add(new Entry(buffer, callback));
_length = Math.addExact(_length, buffer.remaining());
}
/**
@ -49,16 +51,13 @@ public class BufferCallbackAccumulator
*/
public int getLength()
{
int length = 0;
for (Entry entry : _entries)
length = Math.addExact(length, entry.buffer.remaining());
return length;
return _length;
}
/**
* @return a newly allocated byte array containing all content written into the accumulator.
*/
public byte[] toByteArray()
public byte[] takeByteArray()
{
int length = getLength();
if (length == 0)
@ -76,10 +75,16 @@ public class BufferCallbackAccumulator
for (Iterator<Entry> iterator = _entries.iterator(); iterator.hasNext();)
{
Entry entry = iterator.next();
_length = entry.buffer.remaining();
buffer.put(entry.buffer);
iterator.remove();
entry.callback.succeeded();
}
if (!_entries.isEmpty())
throw new IllegalStateException("remaining entries: " + _entries.size());
if (_length != 0)
throw new IllegalStateException("non-zero length: " + _length);
}
public void fail(Throwable t)
@ -89,5 +94,6 @@ public class BufferCallbackAccumulator
entry.callback.failed(t);
}
_entries.clear();
_length = 0;
}
}

View File

@ -53,8 +53,8 @@ public class ByteArrayMessageSink extends AbstractMessageSink
long maxBinaryMessageSize = session.getMaxBinaryMessageSize();
if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize)
{
callback.failed(new MessageTooLargeException(
String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", size, maxBinaryMessageSize)));
throw new MessageTooLargeException(
String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", size, maxBinaryMessageSize));
}
// If we are fin and no OutputStream has been created we don't need to aggregate.
@ -73,13 +73,20 @@ public class ByteArrayMessageSink extends AbstractMessageSink
return;
}
aggregatePayload(frame, callback);
// Aggregate the frame payload.
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new BufferCallbackAccumulator();
out.addEntry(payload, callback);
}
// If the methodHandle throws we don't want to fail callback twice.
callback = Callback.NOOP;
if (frame.isFin())
{
byte[] buf = out.toByteArray();
byte[] buf = out.takeByteArray();
methodHandle.invoke(buf, 0, buf.length);
}
@ -101,15 +108,4 @@ public class ByteArrayMessageSink extends AbstractMessageSink
}
}
}
private void aggregatePayload(Frame frame, Callback callback)
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new BufferCallbackAccumulator();
out.addEntry(payload, callback);
}
}
}

View File

@ -71,7 +71,14 @@ public class ByteBufferMessageSink extends AbstractMessageSink
return;
}
aggregatePayload(frame, callback);
// Aggregate the frame payload.
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new BufferCallbackAccumulator();
out.addEntry(payload, callback);
}
// If the methodHandle throws we don't want to fail callback twice.
callback = Callback.NOOP;
@ -110,15 +117,4 @@ public class ByteBufferMessageSink extends AbstractMessageSink
}
}
}
private void aggregatePayload(Frame frame, Callback callback)
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new BufferCallbackAccumulator();
out.addEntry(payload, callback);
}
}
}

View File

@ -13,48 +13,35 @@
package org.eclipse.jetty.websocket.javax.common;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.NullByteBufferPool;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.WebSocketComponents;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import static org.junit.jupiter.api.Assertions.assertTrue;
public abstract class AbstractSessionTest
{
protected static JavaxWebSocketSession session;
protected static JavaxWebSocketContainer container;
protected static WebSocketComponents components;
protected static JavaxWebSocketContainer container = new DummyContainer();
protected static WebSocketComponents components = new WebSocketComponents();
protected static TestCoreSession coreSession = new TestCoreSession();
@BeforeAll
public static void initSession() throws Exception
{
container = new DummyContainer();
container.start();
components = new WebSocketComponents();
components.start();
Object websocketPojo = new DummyEndpoint();
UpgradeRequest upgradeRequest = new UpgradeRequestAdapter();
JavaxWebSocketFrameHandler frameHandler = container.newFrameHandler(websocketPojo, upgradeRequest);
ByteBufferPool bufferPool = new NullByteBufferPool();
CoreSession coreSession = new CoreSession.Empty()
{
@Override
public WebSocketComponents getWebSocketComponents()
{
return components;
}
@Override
public ByteBufferPool getByteBufferPool()
{
return bufferPool;
}
};
session = new JavaxWebSocketSession(container, coreSession, frameHandler, container.getFrameHandlerFactory()
.newDefaultEndpointConfig(websocketPojo.getClass()));
}
@ -66,6 +53,34 @@ public abstract class AbstractSessionTest
container.stop();
}
public static class TestCoreSession extends CoreSession.Empty
{
private final Semaphore demand = new Semaphore(0);
@Override
public WebSocketComponents getWebSocketComponents()
{
return components;
}
@Override
public ByteBufferPool getByteBufferPool()
{
return components.getBufferPool();
}
public void waitForDemand(long timeout, TimeUnit timeUnit) throws InterruptedException
{
assertTrue(demand.tryAcquire(timeout, timeUnit));
}
@Override
public void demand(long n)
{
demand.release();
}
}
public static class DummyEndpoint extends Endpoint
{
@Override

View File

@ -20,15 +20,12 @@ import java.util.function.Consumer;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.Decoder;
import org.eclipse.jetty.websocket.core.WebSocketComponents;
import org.eclipse.jetty.websocket.javax.common.AbstractSessionTest;
import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketFrameHandlerFactory;
import org.eclipse.jetty.websocket.javax.common.decoders.RegisteredDecoder;
public abstract class AbstractMessageSinkTest extends AbstractSessionTest
{
private final WebSocketComponents _components = new WebSocketComponents();
public List<RegisteredDecoder> toRegisteredDecoderList(Class<? extends Decoder> clazz, Class<?> objectType)
{
Class<? extends Decoder> interfaceType;
@ -43,7 +40,7 @@ public abstract class AbstractMessageSinkTest extends AbstractSessionTest
else
throw new IllegalStateException();
return List.of(new RegisteredDecoder(clazz, interfaceType, objectType, ClientEndpointConfig.Builder.create().build(), _components));
return List.of(new RegisteredDecoder(clazz, interfaceType, objectType, ClientEndpointConfig.Builder.create().build(), components));
}
public <T> MethodHandle getAcceptHandle(Consumer<T> copy, Class<T> type)

View File

@ -18,6 +18,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@ -51,10 +52,11 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest
ByteBuffer data = BufferUtil.toBuffer("Hello World", UTF_8);
sink.accept(new Frame(OpCode.BINARY).setPayload(data), finCallback);
finCallback.get(1, TimeUnit.SECONDS); // wait for callback
coreSession.waitForDemand(1, TimeUnit.SECONDS);
finCallback.get(1, TimeUnit.SECONDS);
ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS);
assertThat("FinCallback.done", finCallback.isDone(), is(true));
assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Hello World"));
assertThat("Writer.contents", byteStream.toString(UTF_8), is("Hello World"));
}
@Test
@ -68,19 +70,22 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest
ByteBuffer data1 = BufferUtil.toBuffer("Hello World", UTF_8);
sink.accept(new Frame(OpCode.BINARY).setPayload(data1).setFin(true), fin1Callback);
fin1Callback.get(1, TimeUnit.SECONDS); // wait for callback (can't sent next message until this callback finishes)
// wait for demand (can't sent next message until a new frame is demanded)
coreSession.waitForDemand(1, TimeUnit.SECONDS);
fin1Callback.get(1, TimeUnit.SECONDS);
ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS);
assertThat("FinCallback.done", fin1Callback.isDone(), is(true));
assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Hello World"));
assertThat("Writer.contents", byteStream.toString(UTF_8), is("Hello World"));
FutureCallback fin2Callback = new FutureCallback();
ByteBuffer data2 = BufferUtil.toBuffer("Greetings Earthling", UTF_8);
sink.accept(new Frame(OpCode.BINARY).setPayload(data2).setFin(true), fin2Callback);
fin2Callback.get(1, TimeUnit.SECONDS); // wait for callback
coreSession.waitForDemand(1, TimeUnit.SECONDS);
fin2Callback.get(1, TimeUnit.SECONDS);
byteStream = copy.poll(1, TimeUnit.SECONDS);
assertThat("FinCallback.done", fin2Callback.isDone(), is(true));
assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Greetings Earthling"));
assertThat("Writer.contents", byteStream.toString(UTF_8), is("Greetings Earthling"));
}
@Test
@ -95,16 +100,19 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest
FutureCallback finCallback = new FutureCallback();
sink.accept(new Frame(OpCode.BINARY).setPayload("Hello").setFin(false), callback1);
sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2);
sink.accept(new Frame(OpCode.CONTINUATION).setPayload("World").setFin(true), finCallback);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("callback1.done", callback1.isDone(), is(true));
finCallback.get(1, TimeUnit.SECONDS); // wait for callback
ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS);
assertThat("Callback1.done", callback1.isDone(), is(true));
assertThat("Callback2.done", callback2.isDone(), is(true));
sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("callback2.done", callback2.isDone(), is(true));
sink.accept(new Frame(OpCode.CONTINUATION).setPayload("World").setFin(true), finCallback);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("finCallback.done", finCallback.isDone(), is(true));
assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Hello, World"));
ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS);
assertThat("Writer.contents", byteStream.toString(UTF_8), is("Hello, World"));
}
@Test
@ -120,18 +128,23 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest
FutureCallback finCallback = new FutureCallback();
sink.accept(new Frame(OpCode.BINARY).setPayload("Greetings").setFin(false), callback1);
sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2);
sink.accept(new Frame(OpCode.CONTINUATION).setPayload("Earthling").setFin(false), callback3);
sink.accept(new Frame(OpCode.CONTINUATION).setPayload(new byte[0]).setFin(true), finCallback);
finCallback.get(5, TimeUnit.SECONDS); // wait for callback
ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("Callback1.done", callback1.isDone(), is(true));
sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("Callback2.done", callback2.isDone(), is(true));
sink.accept(new Frame(OpCode.CONTINUATION).setPayload("Earthling").setFin(false), callback3);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("Callback3.done", callback3.isDone(), is(true));
sink.accept(new Frame(OpCode.CONTINUATION).setPayload(new byte[0]).setFin(true), finCallback);
coreSession.waitForDemand(1, TimeUnit.SECONDS);
assertThat("finCallback.done", finCallback.isDone(), is(true));
assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Greetings, Earthling"));
ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS);
assertThat("Writer.contents", byteStream.toString(UTF_8), is("Greetings, Earthling"));
}
public static class InputStreamCopy implements Consumer<InputStream>
@ -156,9 +169,9 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest
}
}
public ByteArrayOutputStream poll(long time, TimeUnit unit) throws InterruptedException, ExecutionException
public ByteArrayOutputStream poll(long time, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException
{
return streams.poll(time, unit).get();
return Objects.requireNonNull(streams.poll(time, unit)).get(time, unit);
}
}
}