Allow configuration of max size of STOMP command body via
transportConnector option transport.maxDataLength
This commit is contained in:
Timothy Bish 2015-02-11 13:41:04 -05:00
parent 07338e7553
commit 8d4234345b
5 changed files with 216 additions and 9 deletions

View File

@ -31,6 +31,7 @@ public class StompCodec {
final static byte[] crlfcrlf = new byte[]{'\r','\n','\r','\n'};
TcpTransport transport;
StompWireFormat wireFormat;
ByteArrayOutputStream currentCommand = new ByteArrayOutputStream();
boolean processedHeaders = false;
@ -44,6 +45,7 @@ public class StompCodec {
public StompCodec(TcpTransport transport) {
this.transport = transport;
this.wireFormat = (StompWireFormat) transport.getWireFormat();
}
public void parse(ByteArrayInputStream input, int readSize) throws Exception {
@ -68,18 +70,20 @@ public class StompCodec {
currentCommand.write(b);
// end of headers section, parse action and header
if (b == '\n' && (previousByte == '\n' || currentCommand.endsWith(crlfcrlf))) {
StompWireFormat wf = (StompWireFormat) transport.getWireFormat();
DataByteArrayInputStream data = new DataByteArrayInputStream(currentCommand.toByteArray());
action = wf.parseAction(data);
headers = wf.parseHeaders(data);
action = wireFormat.parseAction(data);
headers = wireFormat.parseHeaders(data);
try {
String contentLengthHeader = headers.get(Stomp.Headers.CONTENT_LENGTH);
if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLengthHeader != null) {
contentLength = wf.parseContentLength(contentLengthHeader);
contentLength = wireFormat.parseContentLength(contentLengthHeader);
} else {
contentLength = -1;
}
} catch (ProtocolException ignore) {}
} catch (ProtocolException e) {
transport.doConsume(new StompFrameError(e));
return;
}
processedHeaders = true;
currentCommand.reset();
}
@ -92,6 +96,10 @@ public class StompCodec {
processCommand();
} else {
currentCommand.write(b);
if (currentCommand.size() > wireFormat.getMaxDataLength()) {
transport.doConsume(new StompFrameError(new ProtocolException("The maximum data length was exceeded", true)));
return;
}
}
} else {
// read desired content length

View File

@ -40,8 +40,6 @@ import org.apache.activemq.wireformat.WireFormat;
/**
* An implementation of the {@link Transport} interface for using Stomp over NIO
*
*
*/
public class StompNIOTransport extends TcpTransport {
@ -59,16 +57,19 @@ public class StompNIOTransport extends TcpTransport {
super(wireFormat, socket);
}
@Override
protected void initializeStreams() throws IOException {
channel = socket.getChannel();
channel.configureBlocking(false);
// listen for events telling us when the socket is readable.
selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
@Override
public void onSelect(SelectorSelection selection) {
serviceRead();
}
@Override
public void onError(SelectorSelection selection, Throwable error) {
if (error instanceof IOException) {
onException((IOException)error);
@ -120,12 +121,14 @@ public class StompNIOTransport extends TcpTransport {
}
}
@Override
protected void doStart() throws Exception {
connect();
selection.setInterestOps(SelectionKey.OP_READ);
selection.enable();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
try {
if (selection != null) {

View File

@ -153,4 +153,19 @@ public class StompTransportFilter extends TransportFilter implements StompTransp
}
}
/**
* Sets the maximum number of bytes that the data portion of a STOMP frame is allowed to
* be, any incoming STOMP frame with a data section larger than this value will receive
* an error response.
*
* @param maxDataLength
* size in bytes of the maximum data portion of a STOMP frame.
*/
public void setMaxDataLength(int maxDataLength) {
wireFormat.setMaxDataLength(maxDataLength);
}
public int getMaxDataLength() {
return wireFormat.getMaxDataLength();
}
}

View File

@ -46,8 +46,10 @@ public class StompWireFormat implements WireFormat {
private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
private int version = 1;
private int maxDataLength = MAX_DATA_LENGTH;
private String stompVersion = Stomp.DEFAULT_VERSION;
@Override
public ByteSequence marshal(Object command) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
@ -56,12 +58,14 @@ public class StompWireFormat implements WireFormat {
return baos.toByteSequence();
}
@Override
public Object unmarshal(ByteSequence packet) throws IOException {
ByteArrayInputStream stream = new ByteArrayInputStream(packet);
DataInputStream dis = new DataInputStream(stream);
return unmarshal(dis);
}
@Override
public void marshal(Object command, DataOutput os) throws IOException {
StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
@ -90,6 +94,7 @@ public class StompWireFormat implements WireFormat {
os.write(END_OF_FRAME);
}
@Override
public Object unmarshal(DataInput in) throws IOException {
try {
@ -124,7 +129,7 @@ public class StompWireFormat implements WireFormat {
if (baos == null) {
baos = new ByteArrayOutputStream();
} else if (baos.size() > MAX_DATA_LENGTH) {
} else if (baos.size() > getMaxDataLength()) {
throw new ProtocolException("The maximum data length was exceeded", true);
}
@ -249,7 +254,7 @@ public class StompWireFormat implements WireFormat {
throw new ProtocolException("Specified content-length is not a valid integer", true);
}
if (length > MAX_DATA_LENGTH) {
if (length > getMaxDataLength()) {
throw new ProtocolException("The maximum data length was exceeded", true);
}
@ -277,6 +282,7 @@ public class StompWireFormat implements WireFormat {
}
}
result = new String(stream.toByteArray(), "UTF-8");
stream.close();
}
return result;
@ -315,13 +321,17 @@ public class StompWireFormat implements WireFormat {
}
}
decoded.close();
return new String(decoded.toByteArray(), "UTF-8");
}
@Override
public int getVersion() {
return version;
}
@Override
public void setVersion(int version) {
this.version = version;
}
@ -333,4 +343,12 @@ public class StompWireFormat implements WireFormat {
public void setStompVersion(String stompVersion) {
this.stompVersion = stompVersion;
}
public void setMaxDataLength(int maxDataLength) {
this.maxDataLength = maxDataLength;
}
public int getMaxDataLength() {
return maxDataLength;
}
}

View File

@ -0,0 +1,163 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.stomp;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.net.Socket;
import java.util.Arrays;
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocketFactory;
import org.apache.activemq.broker.TransportConnector;
import org.junit.Test;
public class StompMaxDataSizeTest extends StompTestSupport {
private static final int TEST_MAX_DATA_SIZE = 64 * 1024;
private StompConnection connection;
@Override
public void setUp() throws Exception {
System.setProperty("javax.net.ssl.trustStore", "src/test/resources/client.keystore");
System.setProperty("javax.net.ssl.trustStorePassword", "password");
System.setProperty("javax.net.ssl.trustStoreType", "jks");
System.setProperty("javax.net.ssl.keyStore", "src/test/resources/server.keystore");
System.setProperty("javax.net.ssl.keyStorePassword", "password");
System.setProperty("javax.net.ssl.keyStoreType", "jks");
super.setUp();
}
@Override
public void tearDown() throws Exception {
if (connection != null) {
try {
connection.close();
} catch (Throwable ex) {}
}
super.tearDown();
}
@Override
protected void addStompConnector() throws Exception {
TransportConnector connector = null;
connector = brokerService.addConnector("stomp+ssl://0.0.0.0:"+ sslPort +
"?transport.maxDataLength=" + TEST_MAX_DATA_SIZE);
sslPort = connector.getConnectUri().getPort();
connector = brokerService.addConnector("stomp://0.0.0.0:" + port +
"?transport.maxDataLength=" + TEST_MAX_DATA_SIZE);
port = connector.getConnectUri().getPort();
connector = brokerService.addConnector("stomp+nio://0.0.0.0:" + nioPort +
"?transport.maxDataLength=" + TEST_MAX_DATA_SIZE);
nioPort = connector.getConnectUri().getPort();
connector = brokerService.addConnector("stomp+nio+ssl://0.0.0.0:" + nioSslPort +
"?transport.maxDataLength=" + TEST_MAX_DATA_SIZE);
nioSslPort = connector.getConnectUri().getPort();
}
@Test(timeout = 60000)
public void testOversizedMessageOnPlainSocket() throws Exception {
doTestOversizedMessage(port, false);
}
@Test(timeout = 60000)
public void testOversizedMessageOnNioSocket() throws Exception {
doTestOversizedMessage(nioPort, false);
}
@Test//(timeout = 60000)
public void testOversizedMessageOnSslSocket() throws Exception {
doTestOversizedMessage(sslPort, true);
}
@Test(timeout = 60000)
public void testOversizedMessageOnNioSslSocket() throws Exception {
doTestOversizedMessage(nioSslPort, true);
}
protected void doTestOversizedMessage(int port, boolean useSsl) throws Exception {
stompConnect(port, useSsl);
String frame = "CONNECT\n" + "login:system\n" + "passcode:manager\n\n" + Stomp.NULL;
stompConnection.sendFrame(frame);
frame = stompConnection.receiveFrame();
assertTrue(frame.startsWith("CONNECTED"));
frame = "SUBSCRIBE\n" + "destination:/queue/" + getQueueName() + "\n" + "ack:auto\n\n" + Stomp.NULL;
stompConnection.sendFrame(frame);
int size = 100;
char[] bigBodyArray = new char[size];
Arrays.fill(bigBodyArray, 'a');
String bigBody = new String(bigBodyArray);
frame = "SEND\n" + "destination:/queue/" + getQueueName() + "\n\n" + bigBody + Stomp.NULL;
stompConnection.sendFrame(frame);
StompFrame received = stompConnection.receive();
assertNotNull(received);
assertEquals("MESSAGE", received.getAction());
assertEquals(bigBody, received.getBody());
size = TEST_MAX_DATA_SIZE + 100;
bigBodyArray = new char[size];
Arrays.fill(bigBodyArray, 'a');
bigBody = new String(bigBodyArray);
frame = "SEND\n" + "destination:/queue/" + getQueueName() + "\n\n" + bigBody + Stomp.NULL;
stompConnection.sendFrame(frame);
received = stompConnection.receive(5000);
assertNotNull(received);
assertEquals("ERROR", received.getAction());
}
protected StompConnection stompConnect(int port, boolean ssl) throws Exception {
if (stompConnection == null) {
stompConnection = new StompConnection();
}
Socket socket = null;
if (ssl) {
socket = createSslSocket(port);
} else {
socket = createSocket(port);
}
stompConnection.open(socket);
return stompConnection;
}
protected Socket createSocket(int port) throws IOException {
return new Socket("127.0.0.1", port);
}
protected Socket createSslSocket(int port) throws IOException {
SocketFactory factory = SSLSocketFactory.getDefault();
return factory.createSocket("127.0.0.1", port);
}
}