ARTEMIS-2550 Support Websocket Continuation Frames

Large messages can be split up using Websocket Continuation Frames.
This allows for much smaller buffer sizes to send or receive
potentially very large messages.
This commit is contained in:
Dewald Pretorius 2019-11-12 06:07:52 +02:00 committed by Justin Bertram
parent a2604f09c0
commit 9fac4b866c
4 changed files with 304 additions and 20 deletions

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.core.server.protocol.websocket;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
/**
* This class uses the maximum frame payload size to packetize/frame outbound websocket messages into
* continuation frames.
*/
public class WebSocketFrameEncoder extends ChannelOutboundHandlerAdapter {
private int maxFramePayloadLength;
/**
* @param maxFramePayloadLength
*/
public WebSocketFrameEncoder(int maxFramePayloadLength) {
this.maxFramePayloadLength = maxFramePayloadLength;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof ByteBuf) {
writeContinuationFrame(ctx, (ByteBuf) msg, promise);
} else {
super.write(ctx, msg, promise);
}
}
private void writeContinuationFrame(ChannelHandlerContext ctx, ByteBuf byteBuf, ChannelPromise promise) {
int count = byteBuf.readableBytes();
int length = Math.min(count, maxFramePayloadLength);
boolean finalFragment = length == count;
ByteBuf fragment = Unpooled.buffer(length);
byteBuf.readBytes(fragment, length);
ctx.writeAndFlush(new BinaryWebSocketFrame(finalFragment, 0, fragment), promise);
while ((count = byteBuf.readableBytes()) > 0) {
length = Math.min(count, maxFramePayloadLength);
finalFragment = length == count;
fragment = Unpooled.buffer(length);
byteBuf.readBytes(fragment, length);
ctx.writeAndFlush(new ContinuationWebSocketFrame(finalFragment, 0, fragment), promise);
}
}
}

View File

@ -19,12 +19,9 @@ package org.apache.activemq.artemis.core.server.protocol.websocket;
import java.nio.charset.StandardCharsets;
import java.util.List;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
@ -33,6 +30,7 @@ import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
@ -55,7 +53,6 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
private WebSocketServerHandshaker handshaker;
private List<String> supportedProtocols;
private int maxFramePayloadLength;
private static final BinaryWebSocketEncoder BINARY_WEBSOCKET_ENCODER = new BinaryWebSocketEncoder();
public WebSocketServerHandler(List<String> supportedProtocols, int maxFramePayloadLength) {
this.supportedProtocols = supportedProtocols;
@ -98,7 +95,8 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
if (future.isSuccess()) {
// we need to insert an encoder that takes the underlying ChannelBuffer of a StompFrame.toActiveMQBuffer and
// wrap it in a binary web socket frame before letting the wsencoder send it on the wire
future.channel().pipeline().addAfter("wsencoder", "binary-websocket-encoder", BINARY_WEBSOCKET_ENCODER);
WebSocketFrameEncoder encoder = new WebSocketFrameEncoder(maxFramePayloadLength);
future.channel().pipeline().addAfter("wsencoder", "websocket-frame-encoder", encoder);
} else {
// Handshake failed, fire an exceptionCaught event
future.channel().pipeline().fireExceptionCaught(future.cause());
@ -117,7 +115,7 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
} else if (frame instanceof PingWebSocketFrame) {
ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
return false;
} else if (!(frame instanceof TextWebSocketFrame) && !(frame instanceof BinaryWebSocketFrame)) {
} else if (!(frame instanceof TextWebSocketFrame) && !(frame instanceof BinaryWebSocketFrame) && !(frame instanceof ContinuationWebSocketFrame)) {
throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName()));
}
return true;
@ -150,18 +148,4 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
public HttpRequest getHttpRequest() {
return this.httpRequest;
}
@Sharable
private static final class BinaryWebSocketEncoder extends ChannelOutboundHandlerAdapter {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof ByteBuf) {
msg = new BinaryWebSocketFrame((ByteBuf) msg);
}
ctx.write(msg, promise);
}
}
}

View File

@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.core.server.protocol.websocket;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
/**
* WebSocketContinuationFrameEncoderTest
*/
@RunWith(MockitoJUnitRunner.class)
public class WebSocketFrameEncoderTest {
private int maxFramePayloadLength = 100;
private WebSocketFrameEncoder spy;
@Mock
private ChannelHandlerContext ctx;
@Mock
private ChannelPromise promise;
@Before
public void setUp() throws Exception {
spy = spy(new WebSocketFrameEncoder(maxFramePayloadLength));
}
@Test
public void testWriteNonByteBuf() throws Exception {
Object msg = "Not a ByteBuf";
spy.write(ctx, msg, promise); //test
verify(spy).write(ctx, msg, promise);
verify(ctx).write(msg, promise);
verifyNoMoreInteractions(spy, ctx);
verifyZeroInteractions(promise);
}
@Test
public void testWriteSingleFrame() throws Exception {
String content = "Content MSG length less than max frame payload length: " + maxFramePayloadLength;
ByteBuf msg = Unpooled.copiedBuffer(content, StandardCharsets.UTF_8);
ArgumentCaptor<WebSocketFrame> frameCaptor = ArgumentCaptor.forClass(WebSocketFrame.class);
spy.write(ctx, msg, promise); //test
assertEquals(0, msg.readableBytes());
verify(ctx).writeAndFlush(frameCaptor.capture(), eq(promise));
WebSocketFrame frame = frameCaptor.getValue();
assertTrue(frame instanceof BinaryWebSocketFrame);
assertTrue(frame.isFinalFragment());
assertEquals(content, frame.content().toString(StandardCharsets.UTF_8));
}
@Test
public void testWriteContinuationFrames() throws Exception {
String contentPart = "Content MSG Length @ ";
StringBuilder contentBuilder = new StringBuilder(3 * maxFramePayloadLength);
while (contentBuilder.length() < 2 * maxFramePayloadLength) {
contentBuilder.append(contentPart);
contentBuilder.append(contentBuilder.length());
contentBuilder.append('\n');
}
String content = contentBuilder.toString();
int length = content.length();
assertTrue(length > 2 * maxFramePayloadLength); //at least 3 frames of data
ByteBuf msg = Unpooled.copiedBuffer(content, StandardCharsets.UTF_8);
ArgumentCaptor<WebSocketFrame> frameCaptor = ArgumentCaptor.forClass(WebSocketFrame.class);
spy.write(ctx, msg, promise); //test
assertEquals(0, msg.readableBytes());
verify(spy).write(ctx, msg, promise);
verify(ctx, times(3)).writeAndFlush(frameCaptor.capture(), eq(promise));
List<WebSocketFrame> frames = frameCaptor.getAllValues();
assertEquals(3, frames.size());
int offset = 0;
WebSocketFrame first = frames.get(0);
assertTrue(first instanceof BinaryWebSocketFrame);
assertFalse(first.isFinalFragment());
assertEquals(content.substring(offset, offset + maxFramePayloadLength),
first.content().toString(StandardCharsets.UTF_8));
offset += maxFramePayloadLength;
WebSocketFrame second = frames.get(1);
assertTrue(second instanceof ContinuationWebSocketFrame);
assertFalse(second.isFinalFragment());
assertEquals(content.substring(offset, offset + maxFramePayloadLength),
second.content().toString(StandardCharsets.UTF_8));
offset += maxFramePayloadLength;
WebSocketFrame last = frames.get(2);
assertTrue(last instanceof ContinuationWebSocketFrame);
assertTrue(last.isFinalFragment());
assertEquals(content.substring(offset), last.content().toString(StandardCharsets.UTF_8));
verifyNoMoreInteractions(spy, ctx);
verifyZeroInteractions(promise);
}
}

View File

@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.core.server.protocol.websocket;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
/**
* WebSocketServerHandlerTest
*/
@RunWith(MockitoJUnitRunner.class)
public class WebSocketServerHandlerTest {
private int maxFramePayloadLength;
private List<String> supportedProtocols;
private WebSocketServerHandler spy;
@Before
public void setup() throws Exception {
maxFramePayloadLength = 8192;
supportedProtocols = Arrays.asList("STOMP");
spy = spy(new WebSocketServerHandler(supportedProtocols, maxFramePayloadLength));
}
@Test
public void testRead0HandleContinuationFrame() throws Exception {
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
Object msg = new ContinuationWebSocketFrame();
spy.channelRead0(ctx, msg); //test
verify(spy).channelRead0(ctx, msg);
verify(ctx).fireChannelRead(any(ByteBuf.class));
verifyNoMoreInteractions(spy, ctx);
}
@Test
public void testRead0HandleBinaryFrame() throws Exception {
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
Object msg = new BinaryWebSocketFrame();
spy.channelRead0(ctx, msg); //test
verify(spy).channelRead0(ctx, msg);
verify(ctx).fireChannelRead(any(ByteBuf.class));
verifyNoMoreInteractions(spy, ctx);
}
@Test
public void testRead0HandleTextFrame() throws Exception {
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
Object msg = new TextWebSocketFrame();
spy.channelRead0(ctx, msg); //test
verify(spy).channelRead0(ctx, msg);
verify(ctx).fireChannelRead(any(ByteBuf.class));
verifyNoMoreInteractions(spy, ctx);
}
}