This closes #2895
This commit is contained in:
commit
c1c6a73ce5
|
@ -0,0 +1,67 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.activemq.artemis.core.server.protocol.websocket;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class uses the maximum frame payload size to packetize/frame outbound websocket messages into
|
||||||
|
* continuation frames.
|
||||||
|
*/
|
||||||
|
public class WebSocketFrameEncoder extends ChannelOutboundHandlerAdapter {
|
||||||
|
|
||||||
|
private int maxFramePayloadLength;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param maxFramePayloadLength
|
||||||
|
*/
|
||||||
|
public WebSocketFrameEncoder(int maxFramePayloadLength) {
|
||||||
|
this.maxFramePayloadLength = maxFramePayloadLength;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
|
if (msg instanceof ByteBuf) {
|
||||||
|
writeContinuationFrame(ctx, (ByteBuf) msg, promise);
|
||||||
|
} else {
|
||||||
|
super.write(ctx, msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeContinuationFrame(ChannelHandlerContext ctx, ByteBuf byteBuf, ChannelPromise promise) {
|
||||||
|
int count = byteBuf.readableBytes();
|
||||||
|
int length = Math.min(count, maxFramePayloadLength);
|
||||||
|
boolean finalFragment = length == count;
|
||||||
|
ByteBuf fragment = Unpooled.buffer(length);
|
||||||
|
byteBuf.readBytes(fragment, length);
|
||||||
|
ctx.writeAndFlush(new BinaryWebSocketFrame(finalFragment, 0, fragment), promise);
|
||||||
|
|
||||||
|
while ((count = byteBuf.readableBytes()) > 0) {
|
||||||
|
length = Math.min(count, maxFramePayloadLength);
|
||||||
|
finalFragment = length == count;
|
||||||
|
fragment = Unpooled.buffer(length);
|
||||||
|
byteBuf.readBytes(fragment, length);
|
||||||
|
ctx.writeAndFlush(new ContinuationWebSocketFrame(finalFragment, 0, fragment), promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,12 +19,9 @@ package org.apache.activemq.artemis.core.server.protocol.websocket;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
|
||||||
import io.netty.channel.ChannelPromise;
|
|
||||||
import io.netty.channel.SimpleChannelInboundHandler;
|
import io.netty.channel.SimpleChannelInboundHandler;
|
||||||
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||||
import io.netty.handler.codec.http.FullHttpRequest;
|
import io.netty.handler.codec.http.FullHttpRequest;
|
||||||
|
@ -33,6 +30,7 @@ import io.netty.handler.codec.http.HttpHeaderNames;
|
||||||
import io.netty.handler.codec.http.HttpRequest;
|
import io.netty.handler.codec.http.HttpRequest;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||||
|
@ -55,7 +53,6 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
|
||||||
private WebSocketServerHandshaker handshaker;
|
private WebSocketServerHandshaker handshaker;
|
||||||
private List<String> supportedProtocols;
|
private List<String> supportedProtocols;
|
||||||
private int maxFramePayloadLength;
|
private int maxFramePayloadLength;
|
||||||
private static final BinaryWebSocketEncoder BINARY_WEBSOCKET_ENCODER = new BinaryWebSocketEncoder();
|
|
||||||
|
|
||||||
public WebSocketServerHandler(List<String> supportedProtocols, int maxFramePayloadLength) {
|
public WebSocketServerHandler(List<String> supportedProtocols, int maxFramePayloadLength) {
|
||||||
this.supportedProtocols = supportedProtocols;
|
this.supportedProtocols = supportedProtocols;
|
||||||
|
@ -98,7 +95,8 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
|
||||||
if (future.isSuccess()) {
|
if (future.isSuccess()) {
|
||||||
// we need to insert an encoder that takes the underlying ChannelBuffer of a StompFrame.toActiveMQBuffer and
|
// we need to insert an encoder that takes the underlying ChannelBuffer of a StompFrame.toActiveMQBuffer and
|
||||||
// wrap it in a binary web socket frame before letting the wsencoder send it on the wire
|
// wrap it in a binary web socket frame before letting the wsencoder send it on the wire
|
||||||
future.channel().pipeline().addAfter("wsencoder", "binary-websocket-encoder", BINARY_WEBSOCKET_ENCODER);
|
WebSocketFrameEncoder encoder = new WebSocketFrameEncoder(maxFramePayloadLength);
|
||||||
|
future.channel().pipeline().addAfter("wsencoder", "websocket-frame-encoder", encoder);
|
||||||
} else {
|
} else {
|
||||||
// Handshake failed, fire an exceptionCaught event
|
// Handshake failed, fire an exceptionCaught event
|
||||||
future.channel().pipeline().fireExceptionCaught(future.cause());
|
future.channel().pipeline().fireExceptionCaught(future.cause());
|
||||||
|
@ -117,7 +115,7 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
|
||||||
} else if (frame instanceof PingWebSocketFrame) {
|
} else if (frame instanceof PingWebSocketFrame) {
|
||||||
ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
|
ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
|
||||||
return false;
|
return false;
|
||||||
} else if (!(frame instanceof TextWebSocketFrame) && !(frame instanceof BinaryWebSocketFrame)) {
|
} else if (!(frame instanceof TextWebSocketFrame) && !(frame instanceof BinaryWebSocketFrame) && !(frame instanceof ContinuationWebSocketFrame)) {
|
||||||
throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName()));
|
throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName()));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -150,18 +148,4 @@ public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object>
|
||||||
public HttpRequest getHttpRequest() {
|
public HttpRequest getHttpRequest() {
|
||||||
return this.httpRequest;
|
return this.httpRequest;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Sharable
|
|
||||||
private static final class BinaryWebSocketEncoder extends ChannelOutboundHandlerAdapter {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
|
||||||
if (msg instanceof ByteBuf) {
|
|
||||||
msg = new BinaryWebSocketFrame((ByteBuf) msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.write(msg, promise);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,142 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.activemq.artemis.core.server.protocol.websocket;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
import static org.mockito.Mockito.eq;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* WebSocketContinuationFrameEncoderTest
|
||||||
|
*/
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
|
public class WebSocketFrameEncoderTest {
|
||||||
|
|
||||||
|
private int maxFramePayloadLength = 100;
|
||||||
|
private WebSocketFrameEncoder spy;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private ChannelHandlerContext ctx;
|
||||||
|
@Mock
|
||||||
|
private ChannelPromise promise;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() throws Exception {
|
||||||
|
spy = spy(new WebSocketFrameEncoder(maxFramePayloadLength));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWriteNonByteBuf() throws Exception {
|
||||||
|
Object msg = "Not a ByteBuf";
|
||||||
|
|
||||||
|
spy.write(ctx, msg, promise); //test
|
||||||
|
|
||||||
|
verify(spy).write(ctx, msg, promise);
|
||||||
|
verify(ctx).write(msg, promise);
|
||||||
|
verifyNoMoreInteractions(spy, ctx);
|
||||||
|
verifyZeroInteractions(promise);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWriteSingleFrame() throws Exception {
|
||||||
|
String content = "Content MSG length less than max frame payload length: " + maxFramePayloadLength;
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer(content, StandardCharsets.UTF_8);
|
||||||
|
ArgumentCaptor<WebSocketFrame> frameCaptor = ArgumentCaptor.forClass(WebSocketFrame.class);
|
||||||
|
|
||||||
|
spy.write(ctx, msg, promise); //test
|
||||||
|
|
||||||
|
assertEquals(0, msg.readableBytes());
|
||||||
|
verify(ctx).writeAndFlush(frameCaptor.capture(), eq(promise));
|
||||||
|
WebSocketFrame frame = frameCaptor.getValue();
|
||||||
|
assertTrue(frame instanceof BinaryWebSocketFrame);
|
||||||
|
assertTrue(frame.isFinalFragment());
|
||||||
|
assertEquals(content, frame.content().toString(StandardCharsets.UTF_8));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWriteContinuationFrames() throws Exception {
|
||||||
|
String contentPart = "Content MSG Length @ ";
|
||||||
|
StringBuilder contentBuilder = new StringBuilder(3 * maxFramePayloadLength);
|
||||||
|
|
||||||
|
while (contentBuilder.length() < 2 * maxFramePayloadLength) {
|
||||||
|
contentBuilder.append(contentPart);
|
||||||
|
contentBuilder.append(contentBuilder.length());
|
||||||
|
contentBuilder.append('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
String content = contentBuilder.toString();
|
||||||
|
int length = content.length();
|
||||||
|
assertTrue(length > 2 * maxFramePayloadLength); //at least 3 frames of data
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer(content, StandardCharsets.UTF_8);
|
||||||
|
ArgumentCaptor<WebSocketFrame> frameCaptor = ArgumentCaptor.forClass(WebSocketFrame.class);
|
||||||
|
|
||||||
|
spy.write(ctx, msg, promise); //test
|
||||||
|
|
||||||
|
assertEquals(0, msg.readableBytes());
|
||||||
|
verify(spy).write(ctx, msg, promise);
|
||||||
|
verify(ctx, times(3)).writeAndFlush(frameCaptor.capture(), eq(promise));
|
||||||
|
List<WebSocketFrame> frames = frameCaptor.getAllValues();
|
||||||
|
assertEquals(3, frames.size());
|
||||||
|
|
||||||
|
int offset = 0;
|
||||||
|
WebSocketFrame first = frames.get(0);
|
||||||
|
assertTrue(first instanceof BinaryWebSocketFrame);
|
||||||
|
assertFalse(first.isFinalFragment());
|
||||||
|
assertEquals(content.substring(offset, offset + maxFramePayloadLength),
|
||||||
|
first.content().toString(StandardCharsets.UTF_8));
|
||||||
|
|
||||||
|
offset += maxFramePayloadLength;
|
||||||
|
WebSocketFrame second = frames.get(1);
|
||||||
|
assertTrue(second instanceof ContinuationWebSocketFrame);
|
||||||
|
assertFalse(second.isFinalFragment());
|
||||||
|
assertEquals(content.substring(offset, offset + maxFramePayloadLength),
|
||||||
|
second.content().toString(StandardCharsets.UTF_8));
|
||||||
|
|
||||||
|
offset += maxFramePayloadLength;
|
||||||
|
WebSocketFrame last = frames.get(2);
|
||||||
|
assertTrue(last instanceof ContinuationWebSocketFrame);
|
||||||
|
assertTrue(last.isFinalFragment());
|
||||||
|
assertEquals(content.substring(offset), last.content().toString(StandardCharsets.UTF_8));
|
||||||
|
|
||||||
|
verifyNoMoreInteractions(spy, ctx);
|
||||||
|
verifyZeroInteractions(promise);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.activemq.artemis.core.server.protocol.websocket;
|
||||||
|
|
||||||
|
import static org.mockito.Mockito.any;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* WebSocketServerHandlerTest
|
||||||
|
*/
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
|
public class WebSocketServerHandlerTest {
|
||||||
|
|
||||||
|
private int maxFramePayloadLength;
|
||||||
|
private List<String> supportedProtocols;
|
||||||
|
private WebSocketServerHandler spy;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() throws Exception {
|
||||||
|
maxFramePayloadLength = 8192;
|
||||||
|
supportedProtocols = Arrays.asList("STOMP");
|
||||||
|
spy = spy(new WebSocketServerHandler(supportedProtocols, maxFramePayloadLength));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRead0HandleContinuationFrame() throws Exception {
|
||||||
|
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
|
||||||
|
Object msg = new ContinuationWebSocketFrame();
|
||||||
|
|
||||||
|
spy.channelRead0(ctx, msg); //test
|
||||||
|
|
||||||
|
verify(spy).channelRead0(ctx, msg);
|
||||||
|
verify(ctx).fireChannelRead(any(ByteBuf.class));
|
||||||
|
verifyNoMoreInteractions(spy, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRead0HandleBinaryFrame() throws Exception {
|
||||||
|
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
|
||||||
|
Object msg = new BinaryWebSocketFrame();
|
||||||
|
|
||||||
|
spy.channelRead0(ctx, msg); //test
|
||||||
|
|
||||||
|
verify(spy).channelRead0(ctx, msg);
|
||||||
|
verify(ctx).fireChannelRead(any(ByteBuf.class));
|
||||||
|
verifyNoMoreInteractions(spy, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRead0HandleTextFrame() throws Exception {
|
||||||
|
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
|
||||||
|
Object msg = new TextWebSocketFrame();
|
||||||
|
|
||||||
|
spy.channelRead0(ctx, msg); //test
|
||||||
|
|
||||||
|
verify(spy).channelRead0(ctx, msg);
|
||||||
|
verify(ctx).fireChannelRead(any(ByteBuf.class));
|
||||||
|
verifyNoMoreInteractions(spy, ctx);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue