diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java index f1e1b5bd40..989b1d8b44 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java @@ -31,6 +31,7 @@ public class StompCodec { final static byte[] crlfcrlf = new byte[]{'\r','\n','\r','\n'}; TcpTransport transport; + StompWireFormat wireFormat; ByteArrayOutputStream currentCommand = new ByteArrayOutputStream(); boolean processedHeaders = false; @@ -44,6 +45,7 @@ public class StompCodec { public StompCodec(TcpTransport transport) { this.transport = transport; + this.wireFormat = (StompWireFormat) transport.getWireFormat(); } public void parse(ByteArrayInputStream input, int readSize) throws Exception { @@ -68,18 +70,20 @@ public class StompCodec { currentCommand.write(b); // end of headers section, parse action and header if (b == '\n' && (previousByte == '\n' || currentCommand.endsWith(crlfcrlf))) { - StompWireFormat wf = (StompWireFormat) transport.getWireFormat(); DataByteArrayInputStream data = new DataByteArrayInputStream(currentCommand.toByteArray()); - action = wf.parseAction(data); - headers = wf.parseHeaders(data); + action = wireFormat.parseAction(data); + headers = wireFormat.parseHeaders(data); try { String contentLengthHeader = headers.get(Stomp.Headers.CONTENT_LENGTH); if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLengthHeader != null) { - contentLength = wf.parseContentLength(contentLengthHeader); + contentLength = wireFormat.parseContentLength(contentLengthHeader); } else { contentLength = -1; } - } catch (ProtocolException ignore) {} + } catch (ProtocolException e) { + transport.doConsume(new StompFrameError(e)); + return; + } processedHeaders = true; currentCommand.reset(); } @@ -92,6 +96,10 @@ public class StompCodec { processCommand(); } else { currentCommand.write(b); + if (currentCommand.size() > wireFormat.getMaxDataLength()) { + transport.doConsume(new StompFrameError(new ProtocolException("The maximum data length was exceeded", true))); + return; + } } } else { // read desired content length diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompNIOTransport.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompNIOTransport.java index 2d760d0013..e2be9f9357 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompNIOTransport.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompNIOTransport.java @@ -40,8 +40,6 @@ import org.apache.activemq.wireformat.WireFormat; /** * An implementation of the {@link Transport} interface for using Stomp over NIO - * - * */ public class StompNIOTransport extends TcpTransport { @@ -59,16 +57,19 @@ public class StompNIOTransport extends TcpTransport { super(wireFormat, socket); } + @Override protected void initializeStreams() throws IOException { channel = socket.getChannel(); channel.configureBlocking(false); // listen for events telling us when the socket is readable. selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { + @Override public void onSelect(SelectorSelection selection) { serviceRead(); } + @Override public void onError(SelectorSelection selection, Throwable error) { if (error instanceof IOException) { onException((IOException)error); @@ -120,12 +121,14 @@ public class StompNIOTransport extends TcpTransport { } } + @Override protected void doStart() throws Exception { connect(); selection.setInterestOps(SelectionKey.OP_READ); selection.enable(); } + @Override protected void doStop(ServiceStopper stopper) throws Exception { try { if (selection != null) { diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java index 14c8122266..9cf003ed27 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java @@ -153,4 +153,19 @@ public class StompTransportFilter extends TransportFilter implements StompTransp } } + /** + * Sets the maximum number of bytes that the data portion of a STOMP frame is allowed to + * be, any incoming STOMP frame with a data section larger than this value will receive + * an error response. + * + * @param maxDataLength + * size in bytes of the maximum data portion of a STOMP frame. + */ + public void setMaxDataLength(int maxDataLength) { + wireFormat.setMaxDataLength(maxDataLength); + } + + public int getMaxDataLength() { + return wireFormat.getMaxDataLength(); + } } diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java index 1d38bdb0a2..1a95443131 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java @@ -46,8 +46,10 @@ public class StompWireFormat implements WireFormat { private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100; private int version = 1; + private int maxDataLength = MAX_DATA_LENGTH; private String stompVersion = Stomp.DEFAULT_VERSION; + @Override public ByteSequence marshal(Object command) throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(baos); @@ -56,12 +58,14 @@ public class StompWireFormat implements WireFormat { return baos.toByteSequence(); } + @Override public Object unmarshal(ByteSequence packet) throws IOException { ByteArrayInputStream stream = new ByteArrayInputStream(packet); DataInputStream dis = new DataInputStream(stream); return unmarshal(dis); } + @Override public void marshal(Object command, DataOutput os) throws IOException { StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command; @@ -90,6 +94,7 @@ public class StompWireFormat implements WireFormat { os.write(END_OF_FRAME); } + @Override public Object unmarshal(DataInput in) throws IOException { try { @@ -124,7 +129,7 @@ public class StompWireFormat implements WireFormat { if (baos == null) { baos = new ByteArrayOutputStream(); - } else if (baos.size() > MAX_DATA_LENGTH) { + } else if (baos.size() > getMaxDataLength()) { throw new ProtocolException("The maximum data length was exceeded", true); } @@ -249,7 +254,7 @@ public class StompWireFormat implements WireFormat { throw new ProtocolException("Specified content-length is not a valid integer", true); } - if (length > MAX_DATA_LENGTH) { + if (length > getMaxDataLength()) { throw new ProtocolException("The maximum data length was exceeded", true); } @@ -277,6 +282,7 @@ public class StompWireFormat implements WireFormat { } } result = new String(stream.toByteArray(), "UTF-8"); + stream.close(); } return result; @@ -315,13 +321,17 @@ public class StompWireFormat implements WireFormat { } } + decoded.close(); + return new String(decoded.toByteArray(), "UTF-8"); } + @Override public int getVersion() { return version; } + @Override public void setVersion(int version) { this.version = version; } @@ -333,4 +343,12 @@ public class StompWireFormat implements WireFormat { public void setStompVersion(String stompVersion) { this.stompVersion = stompVersion; } + + public void setMaxDataLength(int maxDataLength) { + this.maxDataLength = maxDataLength; + } + + public int getMaxDataLength() { + return maxDataLength; + } } diff --git a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxDataSizeTest.java b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxDataSizeTest.java new file mode 100644 index 0000000000..338d76a6e1 --- /dev/null +++ b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompMaxDataSizeTest.java @@ -0,0 +1,163 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.stomp; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.net.Socket; +import java.util.Arrays; + +import javax.net.SocketFactory; +import javax.net.ssl.SSLSocketFactory; + +import org.apache.activemq.broker.TransportConnector; +import org.junit.Test; + +public class StompMaxDataSizeTest extends StompTestSupport { + + private static final int TEST_MAX_DATA_SIZE = 64 * 1024; + + private StompConnection connection; + + @Override + public void setUp() throws Exception { + System.setProperty("javax.net.ssl.trustStore", "src/test/resources/client.keystore"); + System.setProperty("javax.net.ssl.trustStorePassword", "password"); + System.setProperty("javax.net.ssl.trustStoreType", "jks"); + System.setProperty("javax.net.ssl.keyStore", "src/test/resources/server.keystore"); + System.setProperty("javax.net.ssl.keyStorePassword", "password"); + System.setProperty("javax.net.ssl.keyStoreType", "jks"); + super.setUp(); + } + + @Override + public void tearDown() throws Exception { + if (connection != null) { + try { + connection.close(); + } catch (Throwable ex) {} + } + super.tearDown(); + } + + @Override + protected void addStompConnector() throws Exception { + TransportConnector connector = null; + + connector = brokerService.addConnector("stomp+ssl://0.0.0.0:"+ sslPort + + "?transport.maxDataLength=" + TEST_MAX_DATA_SIZE); + sslPort = connector.getConnectUri().getPort(); + connector = brokerService.addConnector("stomp://0.0.0.0:" + port + + "?transport.maxDataLength=" + TEST_MAX_DATA_SIZE); + port = connector.getConnectUri().getPort(); + connector = brokerService.addConnector("stomp+nio://0.0.0.0:" + nioPort + + "?transport.maxDataLength=" + TEST_MAX_DATA_SIZE); + nioPort = connector.getConnectUri().getPort(); + connector = brokerService.addConnector("stomp+nio+ssl://0.0.0.0:" + nioSslPort + + "?transport.maxDataLength=" + TEST_MAX_DATA_SIZE); + nioSslPort = connector.getConnectUri().getPort(); + } + + @Test(timeout = 60000) + public void testOversizedMessageOnPlainSocket() throws Exception { + doTestOversizedMessage(port, false); + } + + @Test(timeout = 60000) + public void testOversizedMessageOnNioSocket() throws Exception { + doTestOversizedMessage(nioPort, false); + } + + @Test//(timeout = 60000) + public void testOversizedMessageOnSslSocket() throws Exception { + doTestOversizedMessage(sslPort, true); + } + + @Test(timeout = 60000) + public void testOversizedMessageOnNioSslSocket() throws Exception { + doTestOversizedMessage(nioSslPort, true); + } + + protected void doTestOversizedMessage(int port, boolean useSsl) throws Exception { + stompConnect(port, useSsl); + + String frame = "CONNECT\n" + "login:system\n" + "passcode:manager\n\n" + Stomp.NULL; + stompConnection.sendFrame(frame); + + frame = stompConnection.receiveFrame(); + assertTrue(frame.startsWith("CONNECTED")); + + frame = "SUBSCRIBE\n" + "destination:/queue/" + getQueueName() + "\n" + "ack:auto\n\n" + Stomp.NULL; + stompConnection.sendFrame(frame); + + int size = 100; + char[] bigBodyArray = new char[size]; + Arrays.fill(bigBodyArray, 'a'); + String bigBody = new String(bigBodyArray); + + frame = "SEND\n" + "destination:/queue/" + getQueueName() + "\n\n" + bigBody + Stomp.NULL; + + stompConnection.sendFrame(frame); + + StompFrame received = stompConnection.receive(); + assertNotNull(received); + assertEquals("MESSAGE", received.getAction()); + assertEquals(bigBody, received.getBody()); + + size = TEST_MAX_DATA_SIZE + 100; + bigBodyArray = new char[size]; + Arrays.fill(bigBodyArray, 'a'); + bigBody = new String(bigBodyArray); + + frame = "SEND\n" + "destination:/queue/" + getQueueName() + "\n\n" + bigBody + Stomp.NULL; + + stompConnection.sendFrame(frame); + + received = stompConnection.receive(5000); + assertNotNull(received); + assertEquals("ERROR", received.getAction()); + } + + protected StompConnection stompConnect(int port, boolean ssl) throws Exception { + if (stompConnection == null) { + stompConnection = new StompConnection(); + } + + Socket socket = null; + if (ssl) { + socket = createSslSocket(port); + } else { + socket = createSocket(port); + } + + stompConnection.open(socket); + + return stompConnection; + } + + protected Socket createSocket(int port) throws IOException { + return new Socket("127.0.0.1", port); + } + + protected Socket createSslSocket(int port) throws IOException { + SocketFactory factory = SSLSocketFactory.getDefault(); + return factory.createSocket("127.0.0.1", port); + } +}