diff --git a/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketFrameEncoder.java b/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketFrameEncoder.java new file mode 100644 index 0000000000..703998b5b3 --- /dev/null +++ b/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketFrameEncoder.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.artemis.core.server.protocol.websocket; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; + +/** + * This class uses the maximum frame payload size to packetize/frame outbound websocket messages into + * continuation frames. + */ +public class WebSocketFrameEncoder extends ChannelOutboundHandlerAdapter { + + private int maxFramePayloadLength; + + /** + * @param maxFramePayloadLength + */ + public WebSocketFrameEncoder(int maxFramePayloadLength) { + this.maxFramePayloadLength = maxFramePayloadLength; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof ByteBuf) { + writeContinuationFrame(ctx, (ByteBuf) msg, promise); + } else { + super.write(ctx, msg, promise); + } + } + + private void writeContinuationFrame(ChannelHandlerContext ctx, ByteBuf byteBuf, ChannelPromise promise) { + int count = byteBuf.readableBytes(); + int length = Math.min(count, maxFramePayloadLength); + boolean finalFragment = length == count; + ByteBuf fragment = Unpooled.buffer(length); + byteBuf.readBytes(fragment, length); + ctx.writeAndFlush(new BinaryWebSocketFrame(finalFragment, 0, fragment), promise); + + while ((count = byteBuf.readableBytes()) > 0) { + length = Math.min(count, maxFramePayloadLength); + finalFragment = length == count; + fragment = Unpooled.buffer(length); + byteBuf.readBytes(fragment, length); + ctx.writeAndFlush(new ContinuationWebSocketFrame(finalFragment, 0, fragment), promise); + } + } +} diff --git a/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandler.java b/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandler.java index f764985eea..83a8d0e431 100644 --- a/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandler.java +++ b/artemis-server/src/main/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandler.java @@ -19,12 +19,9 @@ package org.apache.activemq.artemis.core.server.protocol.websocket; import java.nio.charset.StandardCharsets; import java.util.List; -import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; @@ -33,6 +30,7 @@ import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; @@ -55,7 +53,6 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler private WebSocketServerHandshaker handshaker; private List supportedProtocols; private int maxFramePayloadLength; - private static final BinaryWebSocketEncoder BINARY_WEBSOCKET_ENCODER = new BinaryWebSocketEncoder(); public WebSocketServerHandler(List supportedProtocols, int maxFramePayloadLength) { this.supportedProtocols = supportedProtocols; @@ -98,7 +95,8 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler if (future.isSuccess()) { // we need to insert an encoder that takes the underlying ChannelBuffer of a StompFrame.toActiveMQBuffer and // wrap it in a binary web socket frame before letting the wsencoder send it on the wire - future.channel().pipeline().addAfter("wsencoder", "binary-websocket-encoder", BINARY_WEBSOCKET_ENCODER); + WebSocketFrameEncoder encoder = new WebSocketFrameEncoder(maxFramePayloadLength); + future.channel().pipeline().addAfter("wsencoder", "websocket-frame-encoder", encoder); } else { // Handshake failed, fire an exceptionCaught event future.channel().pipeline().fireExceptionCaught(future.cause()); @@ -117,7 +115,7 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler } else if (frame instanceof PingWebSocketFrame) { ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain())); return false; - } else if (!(frame instanceof TextWebSocketFrame) && !(frame instanceof BinaryWebSocketFrame)) { + } else if (!(frame instanceof TextWebSocketFrame) && !(frame instanceof BinaryWebSocketFrame) && !(frame instanceof ContinuationWebSocketFrame)) { throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName())); } return true; @@ -150,18 +148,4 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler public HttpRequest getHttpRequest() { return this.httpRequest; } - - @Sharable - private static final class BinaryWebSocketEncoder extends ChannelOutboundHandlerAdapter { - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof ByteBuf) { - msg = new BinaryWebSocketFrame((ByteBuf) msg); - } - - ctx.write(msg, promise); - } - - } } diff --git a/artemis-server/src/test/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketFrameEncoderTest.java b/artemis-server/src/test/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketFrameEncoderTest.java new file mode 100644 index 0000000000..fc67c9a005 --- /dev/null +++ b/artemis-server/src/test/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketFrameEncoderTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.artemis.core.server.protocol.websocket; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; + +import java.nio.charset.StandardCharsets; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +/** + * WebSocketContinuationFrameEncoderTest + */ +@RunWith(MockitoJUnitRunner.class) +public class WebSocketFrameEncoderTest { + + private int maxFramePayloadLength = 100; + private WebSocketFrameEncoder spy; + + @Mock + private ChannelHandlerContext ctx; + @Mock + private ChannelPromise promise; + + @Before + public void setUp() throws Exception { + spy = spy(new WebSocketFrameEncoder(maxFramePayloadLength)); + } + + @Test + public void testWriteNonByteBuf() throws Exception { + Object msg = "Not a ByteBuf"; + + spy.write(ctx, msg, promise); //test + + verify(spy).write(ctx, msg, promise); + verify(ctx).write(msg, promise); + verifyNoMoreInteractions(spy, ctx); + verifyZeroInteractions(promise); + } + + @Test + public void testWriteSingleFrame() throws Exception { + String content = "Content MSG length less than max frame payload length: " + maxFramePayloadLength; + ByteBuf msg = Unpooled.copiedBuffer(content, StandardCharsets.UTF_8); + ArgumentCaptor frameCaptor = ArgumentCaptor.forClass(WebSocketFrame.class); + + spy.write(ctx, msg, promise); //test + + assertEquals(0, msg.readableBytes()); + verify(ctx).writeAndFlush(frameCaptor.capture(), eq(promise)); + WebSocketFrame frame = frameCaptor.getValue(); + assertTrue(frame instanceof BinaryWebSocketFrame); + assertTrue(frame.isFinalFragment()); + assertEquals(content, frame.content().toString(StandardCharsets.UTF_8)); + } + + @Test + public void testWriteContinuationFrames() throws Exception { + String contentPart = "Content MSG Length @ "; + StringBuilder contentBuilder = new StringBuilder(3 * maxFramePayloadLength); + + while (contentBuilder.length() < 2 * maxFramePayloadLength) { + contentBuilder.append(contentPart); + contentBuilder.append(contentBuilder.length()); + contentBuilder.append('\n'); + } + + String content = contentBuilder.toString(); + int length = content.length(); + assertTrue(length > 2 * maxFramePayloadLength); //at least 3 frames of data + ByteBuf msg = Unpooled.copiedBuffer(content, StandardCharsets.UTF_8); + ArgumentCaptor frameCaptor = ArgumentCaptor.forClass(WebSocketFrame.class); + + spy.write(ctx, msg, promise); //test + + assertEquals(0, msg.readableBytes()); + verify(spy).write(ctx, msg, promise); + verify(ctx, times(3)).writeAndFlush(frameCaptor.capture(), eq(promise)); + List frames = frameCaptor.getAllValues(); + assertEquals(3, frames.size()); + + int offset = 0; + WebSocketFrame first = frames.get(0); + assertTrue(first instanceof BinaryWebSocketFrame); + assertFalse(first.isFinalFragment()); + assertEquals(content.substring(offset, offset + maxFramePayloadLength), + first.content().toString(StandardCharsets.UTF_8)); + + offset += maxFramePayloadLength; + WebSocketFrame second = frames.get(1); + assertTrue(second instanceof ContinuationWebSocketFrame); + assertFalse(second.isFinalFragment()); + assertEquals(content.substring(offset, offset + maxFramePayloadLength), + second.content().toString(StandardCharsets.UTF_8)); + + offset += maxFramePayloadLength; + WebSocketFrame last = frames.get(2); + assertTrue(last instanceof ContinuationWebSocketFrame); + assertTrue(last.isFinalFragment()); + assertEquals(content.substring(offset), last.content().toString(StandardCharsets.UTF_8)); + + verifyNoMoreInteractions(spy, ctx); + verifyZeroInteractions(promise); + } +} diff --git a/artemis-server/src/test/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandlerTest.java b/artemis-server/src/test/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandlerTest.java new file mode 100644 index 0000000000..a7578d7777 --- /dev/null +++ b/artemis-server/src/test/java/org/apache/activemq/artemis/core/server/protocol/websocket/WebSocketServerHandlerTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.artemis.core.server.protocol.websocket; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import java.util.Arrays; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; + +/** + * WebSocketServerHandlerTest + */ +@RunWith(MockitoJUnitRunner.class) +public class WebSocketServerHandlerTest { + + private int maxFramePayloadLength; + private List supportedProtocols; + private WebSocketServerHandler spy; + + @Before + public void setup() throws Exception { + maxFramePayloadLength = 8192; + supportedProtocols = Arrays.asList("STOMP"); + spy = spy(new WebSocketServerHandler(supportedProtocols, maxFramePayloadLength)); + } + + @Test + public void testRead0HandleContinuationFrame() throws Exception { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + Object msg = new ContinuationWebSocketFrame(); + + spy.channelRead0(ctx, msg); //test + + verify(spy).channelRead0(ctx, msg); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + verifyNoMoreInteractions(spy, ctx); + } + + @Test + public void testRead0HandleBinaryFrame() throws Exception { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + Object msg = new BinaryWebSocketFrame(); + + spy.channelRead0(ctx, msg); //test + + verify(spy).channelRead0(ctx, msg); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + verifyNoMoreInteractions(spy, ctx); + } + + @Test + public void testRead0HandleTextFrame() throws Exception { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + Object msg = new TextWebSocketFrame(); + + spy.channelRead0(ctx, msg); //test + + verify(spy).channelRead0(ctx, msg); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + verifyNoMoreInteractions(spy, ctx); + } +}