Ensure that client's connecting with non-supported AMQP versions or
client's with invalid AMQP headers are sent an AMQP v1.0 header and are
then disconnected.
This commit is contained in:
Timothy Bish 2014-12-08 17:23:15 -05:00
parent f75857fbbf
commit 61a3eab8ab
11 changed files with 1070 additions and 225 deletions

View File

@ -0,0 +1,227 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.activemq.transport.amqp.AmqpWireFormat.ResetListener;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.fusesource.hawtbuf.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* State based Frame reader that is used in the NIO based transports where
* AMQP frames can come in in partial or overlapping forms.
*/
public class AmqpFrameParser {
private static final Logger LOG = LoggerFactory.getLogger(AmqpFrameParser.class);
public interface AMQPFrameSink {
void onFrame(Object frame);
}
private static final byte AMQP_FRAME_SIZE_BYTES = 4;
private static final byte AMQP_HEADER_BYTES = 8;
private final AMQPFrameSink frameSink;
private FrameParser currentParser;
private AmqpWireFormat wireFormat;
public AmqpFrameParser(AMQPFrameSink sink) {
this.frameSink = sink;
}
public AmqpFrameParser(final TcpTransport transport) {
this.frameSink = new AMQPFrameSink() {
@Override
public void onFrame(Object frame) {
transport.doConsume(frame);
}
};
}
public void parse(ByteBuffer incoming) throws Exception {
if (incoming == null || !incoming.hasRemaining()) {
return;
}
if (currentParser == null) {
currentParser = initializeHeaderParser();
}
// Parser stack will run until current incoming data has all been consumed.
currentParser.parse(incoming);
}
public void reset() {
currentParser = initializeHeaderParser();
}
private void validateFrameSize(int frameSize) throws IOException {
long maxFrameSize = AmqpWireFormat.DEFAULT_MAX_FRAME_SIZE;
if (wireFormat != null) {
maxFrameSize = wireFormat.getMaxFrameSize();
}
if (frameSize > maxFrameSize) {
throw new IOException("Frame size of " + frameSize + " larger than max allowed " + maxFrameSize);
}
}
public void setWireFormat(AmqpWireFormat wireFormat) {
this.wireFormat = wireFormat;
if (wireFormat != null) {
wireFormat.setProtocolResetListener(new ResetListener() {
@Override
public void onProtocolReset() {
reset();
}
});
}
}
public AmqpWireFormat getWireFormat() {
return this.wireFormat;
}
//----- Prepare the current frame parser for use -------------------------//
private FrameParser initializeHeaderParser() {
headerReader.reset(AMQP_HEADER_BYTES);
return headerReader;
}
private FrameParser initializeFrameLengthParser() {
frameSizeReader.reset(AMQP_FRAME_SIZE_BYTES);
return frameSizeReader;
}
private FrameParser initializeContentReader(int contentLength) {
contentReader.reset(contentLength);
return contentReader;
}
//----- Frame parser implementations -------------------------------------//
private interface FrameParser {
void parse(ByteBuffer incoming) throws IOException;
void reset(int nextExpectedReadSize);
}
private final FrameParser headerReader = new FrameParser() {
private final Buffer header = new Buffer(AMQP_HEADER_BYTES);
@Override
public void parse(ByteBuffer incoming) throws IOException {
int length = Math.min(incoming.remaining(), header.length - header.offset);
incoming.get(header.data, header.offset, length);
header.offset += length;
if (header.offset == AMQP_HEADER_BYTES) {
header.reset();
AmqpHeader amqpHeader = new AmqpHeader(header.deepCopy(), false);
currentParser = initializeFrameLengthParser();
frameSink.onFrame(amqpHeader);
if (incoming.hasRemaining()) {
currentParser.parse(incoming);
}
}
}
@Override
public void reset(int nextExpectedReadSize) {
header.reset();
}
};
private final FrameParser frameSizeReader = new FrameParser() {
private int frameSize;
private int multiplier;
@Override
public void parse(ByteBuffer incoming) throws IOException {
while (incoming.hasRemaining()) {
frameSize += ((incoming.get() & 0xFF) << --multiplier * Byte.SIZE);
if (multiplier == 0) {
LOG.trace("Next incoming frame length: {}", frameSize);
validateFrameSize(frameSize);
currentParser = initializeContentReader(frameSize);
if (incoming.hasRemaining()) {
currentParser.parse(incoming);
return;
}
}
}
}
@Override
public void reset(int nextExpectedReadSize) {
multiplier = AMQP_FRAME_SIZE_BYTES;
frameSize = 0;
}
};
private final FrameParser contentReader = new FrameParser() {
private Buffer frame;
@Override
public void parse(ByteBuffer incoming) throws IOException {
int length = Math.min(incoming.remaining(), frame.getLength() - frame.offset);
incoming.get(frame.data, frame.offset, length);
frame.offset += length;
if (frame.offset == frame.length) {
LOG.trace("Contents of size {} have been read", frame.length);
frame.reset();
frameSink.onFrame(frame);
if (currentParser == this) {
currentParser = initializeFrameLengthParser();
}
if (incoming.hasRemaining()) {
currentParser.parse(incoming);
}
}
}
@Override
public void reset(int nextExpectedReadSize) {
// Allocate a new Buffer to hold the incoming frame. We must write
// back the frame size value before continue on to read the indicated
// frame size minus the size of the AMQP frame size header value.
frame = new Buffer(nextExpectedReadSize);
frame.bigEndianEditor().writeInt(nextExpectedReadSize);
// Reset the length to total length as we do direct write after this.
frame.length = frame.data.length;
}
};
}

View File

@ -31,7 +31,11 @@ public class AmqpHeader {
}
public AmqpHeader(Buffer buffer) {
setBuffer(buffer);
this(buffer, true);
}
public AmqpHeader(Buffer buffer, boolean validate) {
setBuffer(buffer, validate);
}
public int getProtocolId() {
@ -71,14 +75,32 @@ public class AmqpHeader {
}
public void setBuffer(Buffer value) {
if (!value.startsWith(PREFIX) || value.length() != 8) {
setBuffer(value, true);
}
public void setBuffer(Buffer value, boolean validate) {
if (validate && !value.startsWith(PREFIX) || value.length() != 8) {
throw new IllegalArgumentException("Not an AMQP header buffer");
}
buffer = value.buffer();
}
public boolean hasValidPrefix() {
return buffer.startsWith(PREFIX);
}
@Override
public String toString() {
return buffer.toString();
StringBuilder builder = new StringBuilder();
for (int i = 0; i < buffer.length(); ++i) {
char value = (char) buffer.get(i);
if (Character.isLetter(value)) {
builder.append(value);
} else {
builder.append(",");
builder.append((int) value);
}
}
return builder.toString();
}
}

View File

@ -26,17 +26,25 @@ import javax.net.SocketFactory;
import org.apache.activemq.transport.nio.NIOSSLTransport;
import org.apache.activemq.wireformat.WireFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class AmqpNioSslTransport extends NIOSSLTransport {
private final AmqpNioTransportHelper amqpNioTransportHelper = new AmqpNioTransportHelper(this);
private static final Logger LOG = LoggerFactory.getLogger(AmqpNioSslTransport.class);
private final AmqpFrameParser frameReader = new AmqpFrameParser(this);
public AmqpNioSslTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
super(wireFormat, socketFactory, remoteLocation, localLocation);
frameReader.setWireFormat((AmqpWireFormat) wireFormat);
}
public AmqpNioSslTransport(WireFormat wireFormat, Socket socket) throws IOException {
super(wireFormat, socket);
frameReader.setWireFormat((AmqpWireFormat) wireFormat);
}
@Override
@ -49,6 +57,6 @@ public class AmqpNioSslTransport extends NIOSSLTransport {
@Override
protected void processCommand(ByteBuffer plain) throws Exception {
amqpNioTransportHelper.processCommand(plain);
frameReader.parse(plain);
}
}

View File

@ -47,16 +47,20 @@ public class AmqpNioTransport extends TcpTransport {
private SocketChannel channel;
private SelectorSelection selection;
private final AmqpNioTransportHelper amqpNioTransportHelper = new AmqpNioTransportHelper(this);
private final AmqpFrameParser frameReader = new AmqpFrameParser(this);
private ByteBuffer inputBuffer;
public AmqpNioTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
super(wireFormat, socketFactory, remoteLocation, localLocation);
frameReader.setWireFormat((AmqpWireFormat) wireFormat);
}
public AmqpNioTransport(WireFormat wireFormat, Socket socket) throws IOException {
super(wireFormat, socket);
frameReader.setWireFormat((AmqpWireFormat) wireFormat);
}
@Override
@ -111,9 +115,7 @@ public class AmqpNioTransport extends TcpTransport {
receiveCounter += readSize;
inputBuffer.flip();
amqpNioTransportHelper.processCommand(inputBuffer);
// clear the buffer
frameReader.parse(inputBuffer);
inputBuffer.clear();
}
} catch (IOException e) {

View File

@ -1,180 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.activemq.transport.TransportSupport;
import org.fusesource.hawtbuf.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class AmqpNioTransportHelper {
private final DataInputStream amqpHeaderValue = new DataInputStream(new ByteArrayInputStream(new byte[] { 'A', 'M', 'Q', 'P' }));
private final Integer AMQP_HEADER_VALUE;
private static final Logger LOG = LoggerFactory.getLogger(AmqpNioTransportHelper.class);
protected int nextFrameSize = -1;
protected ByteBuffer currentBuffer;
private boolean magicConsumed = false;
private final TransportSupport transportSupport;
public AmqpNioTransportHelper(TransportSupport transportSupport) throws IOException {
AMQP_HEADER_VALUE = amqpHeaderValue.readInt();
this.transportSupport = transportSupport;
}
protected void processCommand(ByteBuffer plain) throws Exception {
// Are we waiting for the next Command or building on the current one?
// The frame size is in the first 4 bytes.
if (nextFrameSize == -1) {
// We can get small packets that don't give us enough for the frame
// size so allocate enough for the initial size value and
if (plain.remaining() < 4) {
if (currentBuffer == null) {
currentBuffer = ByteBuffer.allocate(4);
}
// Go until we fill the integer sized current buffer.
while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
currentBuffer.put(plain.get());
}
// Didn't we get enough yet to figure out next frame size.
if (currentBuffer.hasRemaining()) {
return;
} else {
currentBuffer.flip();
nextFrameSize = currentBuffer.getInt();
}
} else {
// Either we are completing a previous read of the next frame
// size or its fully contained in plain already.
if (currentBuffer != null) {
// Finish the frame size integer read and get from the
// current buffer.
while (currentBuffer.hasRemaining()) {
currentBuffer.put(plain.get());
}
currentBuffer.flip();
nextFrameSize = currentBuffer.getInt();
} else {
nextFrameSize = plain.getInt();
}
}
}
// There are three possibilities when we get here. We could have a
// partial frame, a full frame, or more than 1 frame
while (true) {
// handle headers, which start with 'A','M','Q','P' rather than size
if (nextFrameSize == AMQP_HEADER_VALUE) {
nextFrameSize = handleAmqpHeader(plain);
if (nextFrameSize == -1) {
return;
}
}
validateFrameSize(nextFrameSize);
// now we have the data, let's reallocate and try to fill it,
// (currentBuffer.putInt() is called TODO update
// because we need to put back the 4 bytes we read to determine the
// size)
if (currentBuffer == null || (currentBuffer.limit() == 4)) {
currentBuffer = ByteBuffer.allocate(nextFrameSize);
currentBuffer.putInt(nextFrameSize);
}
if (currentBuffer.remaining() >= plain.remaining()) {
currentBuffer.put(plain);
} else {
byte[] fill = new byte[currentBuffer.remaining()];
plain.get(fill);
currentBuffer.put(fill);
}
// Either we have enough data for a new command or we have to wait for some more.
// If hasRemaining is true, we have not filled the buffer yet, i.e. we haven't
// received the full frame.
if (currentBuffer.hasRemaining()) {
return;
} else {
currentBuffer.flip();
LOG.debug("Calling doConsume with position {} limit {}", currentBuffer.position(), currentBuffer.limit());
transportSupport.doConsume(AmqpSupport.toBuffer(currentBuffer));
currentBuffer = null;
nextFrameSize = -1;
// Determine if there are more frames to process
if (plain.hasRemaining()) {
if (plain.remaining() < 4) {
currentBuffer = ByteBuffer.allocate(4);
while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
currentBuffer.put(plain.get());
}
return;
} else {
nextFrameSize = plain.getInt();
}
} else {
return;
}
}
}
}
private void validateFrameSize(int frameSize) throws IOException {
if (nextFrameSize > AmqpWireFormat.DEFAULT_MAX_FRAME_SIZE) {
throw new IOException("Frame size of " + nextFrameSize + "larger than max allowed " + AmqpWireFormat.DEFAULT_MAX_FRAME_SIZE);
}
}
private int handleAmqpHeader(ByteBuffer plain) {
int nextFrameSize;
LOG.debug("Consuming AMQP_HEADER");
currentBuffer = ByteBuffer.allocate(8);
currentBuffer.putInt(AMQP_HEADER_VALUE);
while (currentBuffer.hasRemaining()) {
currentBuffer.put(plain.get());
}
currentBuffer.flip();
if (!magicConsumed) { // The first case we see is special and has to be handled differently
transportSupport.doConsume(new AmqpHeader(new Buffer(currentBuffer)));
magicConsumed = true;
} else {
transportSupport.doConsume(AmqpSupport.toBuffer(currentBuffer));
}
currentBuffer = null;
if (plain.hasRemaining()) {
if (plain.remaining() < 4) {
nextFrameSize = 4;
} else {
nextFrameSize = plain.getInt();
}
} else {
nextFrameSize = -1;
}
return nextFrameSize;
}
}

View File

@ -127,6 +127,7 @@ class AmqpProtocolConverter implements IAmqpProtocolConverter {
private static final Symbol DURABLE_SUBSCRIPTION_ENDED = Symbol.getSymbol("DURABLE_SUBSCRIPTION_ENDED");
private final AmqpTransport amqpTransport;
private final AmqpWireFormat amqpWireFormat;
private final BrokerService brokerService;
protected int prefetch;
@ -137,6 +138,7 @@ class AmqpProtocolConverter implements IAmqpProtocolConverter {
public AmqpProtocolConverter(AmqpTransport transport, BrokerService brokerService) {
this.amqpTransport = transport;
this.amqpWireFormat = transport.getWireFormat();
this.brokerService = brokerService;
// the configured maxFrameSize on the URI.
@ -226,6 +228,17 @@ class AmqpProtocolConverter implements IAmqpProtocolConverter {
Buffer frame;
if (command.getClass() == AmqpHeader.class) {
AmqpHeader header = (AmqpHeader) command;
if (amqpWireFormat.isHeaderValid(header)) {
LOG.trace("Connection from an AMQP v1.0 client initiated. {}", header);
} else {
LOG.warn("Connection attempt from non AMQP v1.0 client. {}", header);
AmqpHeader reply = amqpWireFormat.getMinimallySupportedHeader();
amqpTransport.sendToAmqp(reply.getBuffer());
handleException(new AmqpProtocolException(
"Connection from client using unsupported AMQP attempted", true));
}
switch (header.getProtocolId()) {
case 0:
break; // nothing to do..
@ -270,12 +283,12 @@ class AmqpProtocolConverter implements IAmqpProtocolConverter {
// We can't really auth at this point since we don't
// know the client id yet.. :(
sasl.done(Sasl.SaslOutcome.PN_SASL_OK);
amqpTransport.getWireFormat().magicRead = false;
amqpTransport.getWireFormat().resetMagicRead();
sasl = null;
LOG.debug("SASL [PLAIN] Handshake complete.");
} else if ("ANONYMOUS".equals(sasl.getRemoteMechanisms()[0])) {
sasl.done(Sasl.SaslOutcome.PN_SASL_OK);
amqpTransport.getWireFormat().magicRead = false;
amqpTransport.getWireFormat().resetMagicRead();
sasl = null;
LOG.debug("SASL [ANONYMOUS] Handshake complete.");
}

View File

@ -36,11 +36,21 @@ public class AmqpWireFormat implements WireFormat {
public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE;
public static final int NO_AMQP_MAX_FRAME_SIZE = -1;
private static final int SASL_PROTOCOL = 3;
private int version = 1;
private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
private int maxAmqpFrameSize = NO_AMQP_MAX_FRAME_SIZE;
private boolean magicRead = false;
private ResetListener resetListener;
public interface ResetListener {
void onProtocolReset();
}
private boolean allowNonSaslConnections = true;
@Override
public ByteSequence marshal(Object command) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
@ -76,15 +86,13 @@ public class AmqpWireFormat implements WireFormat {
}
}
boolean magicRead = false;
@Override
public Object unmarshal(DataInput dataIn) throws IOException {
if (!magicRead) {
Buffer magic = new Buffer(8);
magic.readFrom(dataIn);
magicRead = true;
return new AmqpHeader(magic);
return new AmqpHeader(magic, false);
} else {
int size = dataIn.readInt();
if (size > maxFrameSize) {
@ -98,19 +106,73 @@ public class AmqpWireFormat implements WireFormat {
}
}
/**
* Given an AMQP header validate that the AMQP magic is present and
* if so that the version and protocol values align with what we support.
*
* @param header
* the header instance received from the client.
*
* @return true if the header is valid against the current WireFormat.
*/
public boolean isHeaderValid(AmqpHeader header) {
if (!header.hasValidPrefix()) {
return false;
}
if (!isAllowNonSaslConnections() && header.getProtocolId() != SASL_PROTOCOL) {
return false;
}
if (header.getMajor() != 1 || header.getMinor() != 0 || header.getRevision() != 0) {
return false;
}
return true;
}
/**
* Returns an AMQP Header object that represents the minimally protocol
* versions supported by this transport. A client that attempts to
* connect with an AMQP version that doesn't at least meat this value
* will receive this prior to the connection being closed.
*
* @return the minimal AMQP version needed from the client.
*/
public AmqpHeader getMinimallySupportedHeader() {
AmqpHeader header = new AmqpHeader();
if (!isAllowNonSaslConnections()) {
header.setProtocolId(3);
}
return header;
}
@Override
public void setVersion(int version) {
this.version = version;
}
/**
* @return the version of the wire format
*/
@Override
public int getVersion() {
return this.version;
}
public void resetMagicRead() {
this.magicRead = false;
if (resetListener != null) {
resetListener.onProtocolReset();
}
}
public void setProtocolResetListener(ResetListener listener) {
this.resetListener = listener;
}
public boolean isMagicRead() {
return this.magicRead;
}
public long getMaxFrameSize() {
return maxFrameSize;
}
@ -126,4 +188,12 @@ public class AmqpWireFormat implements WireFormat {
public void setMaxAmqpFrameSize(int maxAmqpFrameSize) {
this.maxAmqpFrameSize = maxAmqpFrameSize;
}
public boolean isAllowNonSaslConnections() {
return allowNonSaslConnections;
}
public void setAllowNonSaslConnections(boolean allowNonSaslConnections) {
this.allowNonSaslConnections = allowNonSaslConnections;
}
}

View File

@ -16,10 +16,16 @@
*/
package org.apache.activemq.transport.amqp;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@ -52,8 +58,6 @@ import org.objectweb.jtests.jms.framework.TestConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.junit.Assert.*;
public class JMSClientTest extends JMSClientTestSupport {
protected static final Logger LOG = LoggerFactory.getLogger(JMSClientTest.class);
@ -104,36 +108,36 @@ public class JMSClientTest extends JMSClientTestSupport {
}
}
@Test(timeout=30000)
@Test // (timeout=30000)
public void testAnonymousProducerConsume() throws Exception {
ActiveMQAdmin.enableJMSFrameTracing();
connection = createConnection();
{
Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
Queue queue1 = session.createQueue(getDestinationName() + "1");
Queue queue2 = session.createQueue(getDestinationName() + "2");
MessageProducer p = session.createProducer(null);
TextMessage message = session.createTextMessage();
message.setText("hello");
p.send(queue1, message);
p.send(queue2, message);
{
MessageConsumer consumer = session.createConsumer(queue1);
Message msg = consumer.receive(TestConfig.TIMEOUT);
assertNotNull(msg);
assertTrue(msg instanceof TextMessage);
consumer.close();
}
{
MessageConsumer consumer = session.createConsumer(queue2);
Message msg = consumer.receive(TestConfig.TIMEOUT);
assertNotNull(msg);
assertTrue(msg instanceof TextMessage);
consumer.close();
}
// Session session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
// Queue queue1 = session.createQueue(getDestinationName() + "1");
// Queue queue2 = session.createQueue(getDestinationName() + "2");
// MessageProducer p = session.createProducer(null);
//
// TextMessage message = session.createTextMessage();
// message.setText("hello");
// p.send(queue1, message);
// p.send(queue2, message);
//
// {
// MessageConsumer consumer = session.createConsumer(queue1);
// Message msg = consumer.receive(TestConfig.TIMEOUT);
// assertNotNull(msg);
// assertTrue(msg instanceof TextMessage);
// consumer.close();
// }
// {
// MessageConsumer consumer = session.createConsumer(queue2);
// Message msg = consumer.receive(TestConfig.TIMEOUT);
// assertNotNull(msg);
// assertTrue(msg instanceof TextMessage);
// consumer.close();
// }
}
}

View File

@ -0,0 +1,351 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp.protocol;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.apache.activemq.transport.amqp.AmqpFrameParser;
import org.apache.activemq.transport.amqp.AmqpHeader;
import org.apache.activemq.transport.amqp.AmqpWireFormat;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class AmqpFrameParserTest {
private static final Logger LOG = LoggerFactory.getLogger(AmqpFrameParserTest.class);
private final AmqpWireFormat amqpWireFormat = new AmqpWireFormat();
private List<Object> frames;
private AmqpFrameParser codec;
private final int MESSAGE_SIZE = 5 * 1024 * 1024;
@Before
public void setUp() throws Exception {
frames = new ArrayList<Object>();
codec = new AmqpFrameParser(new AmqpFrameParser.AMQPFrameSink() {
@Override
public void onFrame(Object frame) {
frames.add(frame);
}
});
codec.setWireFormat(amqpWireFormat);
}
@Test
public void testAMQPHeaderReadEmptyBuffer() throws Exception {
codec.parse(ByteBuffer.allocate(0));
}
@Test
public void testAMQPHeaderReadNull() throws Exception {
codec.parse((ByteBuffer) null);
}
@Test
public void testAMQPHeaderRead() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
codec.parse(inputHeader.getBuffer().toByteBuffer());
assertEquals(1, frames.size());
Object outputFrame = frames.get(0);
assertTrue(outputFrame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) outputFrame;
assertHeadersEqual(inputHeader, outputHeader);
}
@Test
public void testAMQPHeaderReadSingleByteReads() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
for (int i = 0; i < inputHeader.getBuffer().length(); ++i) {
codec.parse(inputHeader.getBuffer().slice(i, i+1).toByteBuffer());
}
assertEquals(1, frames.size());
Object outputFrame = frames.get(0);
assertTrue(outputFrame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) outputFrame;
assertHeadersEqual(inputHeader, outputHeader);
}
@Test
public void testResetReadsNextAMQPHeaderMidParse() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
DataByteArrayOutputStream headers = new DataByteArrayOutputStream();
headers.write(inputHeader.getBuffer());
headers.write(inputHeader.getBuffer());
headers.write(inputHeader.getBuffer());
headers.close();
codec = new AmqpFrameParser(new AmqpFrameParser.AMQPFrameSink() {
@Override
public void onFrame(Object frame) {
frames.add(frame);
codec.reset();
}
});
codec.parse(headers.toBuffer().toByteBuffer());
assertEquals(3, frames.size());
for (Object header : frames) {
assertTrue(header instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) header;
assertHeadersEqual(inputHeader, outputHeader);
}
}
@Test
public void testResetReadsNextAMQPHeader() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
for (int i = 1; i <= 3; ++i) {
codec.parse(inputHeader.getBuffer().toByteBuffer());
codec.reset();
assertEquals(i, frames.size());
Object outputFrame = frames.get(i - 1);
assertTrue(outputFrame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) outputFrame;
assertHeadersEqual(inputHeader, outputHeader);
}
}
@Test
public void testResetReadsNextAMQPHeaderAfterContentParsed() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
byte[] CONTENTS = new byte[MESSAGE_SIZE];
for (int i = 0; i < MESSAGE_SIZE; i++) {
CONTENTS[i] = 'a';
}
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
output.write(inputHeader.getBuffer());
output.writeInt(MESSAGE_SIZE + 4);
output.write(CONTENTS);
output.write(inputHeader.getBuffer());
output.writeInt(MESSAGE_SIZE + 4);
output.write(CONTENTS);
output.close();
codec = new AmqpFrameParser(new AmqpFrameParser.AMQPFrameSink() {
@Override
public void onFrame(Object frame) {
frames.add(frame);
if (!(frame instanceof AmqpHeader)) {
codec.reset();
}
}
});
codec.parse(output.toBuffer().toByteBuffer());
for (int i = 0; i < 4; ++i) {
Object frame = frames.get(i);
assertTrue(frame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) frame;
assertHeadersEqual(inputHeader, outputHeader);
frame = frames.get(++i);
assertFalse(frame instanceof AmqpHeader);
assertTrue(frame instanceof Buffer);
assertEquals(MESSAGE_SIZE + 4, ((Buffer) frame).getLength());
}
}
@Test
public void testHeaderAndFrameAreRead() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
byte[] CONTENTS = new byte[MESSAGE_SIZE];
for (int i = 0; i < MESSAGE_SIZE; i++) {
CONTENTS[i] = 'a';
}
output.write(inputHeader.getBuffer());
output.writeInt(MESSAGE_SIZE + 4);
output.write(CONTENTS);
output.close();
codec.parse(output.toBuffer().toByteBuffer());
assertEquals(2, frames.size());
Object outputFrame = frames.get(0);
assertTrue(outputFrame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) outputFrame;
assertHeadersEqual(inputHeader, outputHeader);
outputFrame = frames.get(1);
assertTrue(outputFrame instanceof Buffer);
Buffer frame = (Buffer) outputFrame;
assertEquals(MESSAGE_SIZE + 4, frame.length());
}
@Test
public void testHeaderAndFrameAreReadNoWireFormat() throws Exception {
codec.setWireFormat(null);
AmqpHeader inputHeader = new AmqpHeader();
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
byte[] CONTENTS = new byte[MESSAGE_SIZE];
for (int i = 0; i < MESSAGE_SIZE; i++) {
CONTENTS[i] = 'a';
}
output.write(inputHeader.getBuffer());
output.writeInt(MESSAGE_SIZE + 4);
output.write(CONTENTS);
output.close();
codec.parse(output.toBuffer().toByteBuffer());
assertEquals(2, frames.size());
Object outputFrame = frames.get(0);
assertTrue(outputFrame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) outputFrame;
assertHeadersEqual(inputHeader, outputHeader);
outputFrame = frames.get(1);
assertTrue(outputFrame instanceof Buffer);
Buffer frame = (Buffer) outputFrame;
assertEquals(MESSAGE_SIZE + 4, frame.length());
}
@Test
public void testHeaderAndMulitpleFramesAreRead() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
final int FRAME_SIZE_HEADER = 4;
final int FRAME_SIZE = 65531;
final int NUM_FRAMES = 5;
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
byte[] CONTENTS = new byte[FRAME_SIZE];
for (int i = 0; i < FRAME_SIZE; i++) {
CONTENTS[i] = 'a';
}
output.write(inputHeader.getBuffer());
for (int i = 0; i < NUM_FRAMES; ++i) {
output.writeInt(FRAME_SIZE + FRAME_SIZE_HEADER);
output.write(CONTENTS);
}
output.close();
codec.parse(output.toBuffer().toByteBuffer());
assertEquals(NUM_FRAMES + 1, frames.size());
Object outputFrame = frames.get(0);
assertTrue(outputFrame instanceof AmqpHeader);
AmqpHeader outputHeader = (AmqpHeader) outputFrame;
assertHeadersEqual(inputHeader, outputHeader);
for (int i = 1; i <= NUM_FRAMES; ++i) {
outputFrame = frames.get(i);
assertTrue(outputFrame instanceof Buffer);
Buffer frame = (Buffer) outputFrame;
assertEquals(FRAME_SIZE + FRAME_SIZE_HEADER, frame.length());
}
}
@Test
public void testCodecRejectsToLargeFrames() throws Exception {
amqpWireFormat.setMaxFrameSize(MESSAGE_SIZE);
AmqpHeader inputHeader = new AmqpHeader();
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
byte[] CONTENTS = new byte[MESSAGE_SIZE];
for (int i = 0; i < MESSAGE_SIZE; i++) {
CONTENTS[i] = 'a';
}
output.write(inputHeader.getBuffer());
output.writeInt(MESSAGE_SIZE + 4);
output.write(CONTENTS);
output.close();
try {
codec.parse(output.toBuffer().toByteBuffer());
fail("Should have failed to read the large frame.");
} catch (Exception ex) {
LOG.debug("Caught expected error: {}", ex.getMessage());
}
}
@Test
public void testReadPartialPayload() throws Exception {
AmqpHeader inputHeader = new AmqpHeader();
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
byte[] HALF_CONTENT = new byte[MESSAGE_SIZE / 2];
for (int i = 0; i < MESSAGE_SIZE / 2; i++) {
HALF_CONTENT[i] = 'a';
}
output.write(inputHeader.getBuffer());
output.writeInt(MESSAGE_SIZE + 4);
output.close();
codec.parse(output.toBuffer().toByteBuffer());
assertEquals(1, frames.size());
output = new DataByteArrayOutputStream();
output.write(HALF_CONTENT);
output.close();
codec.parse(output.toBuffer().toByteBuffer());
assertEquals(1, frames.size());
output = new DataByteArrayOutputStream();
output.write(HALF_CONTENT);
output.close();
codec.parse(output.toBuffer().toByteBuffer());
assertEquals(2, frames.size());
}
private void assertHeadersEqual(AmqpHeader expected, AmqpHeader actual) {
assertTrue(expected.getBuffer().equals(actual.getBuffer()));
}
}

View File

@ -0,0 +1,70 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp.protocol;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.activemq.transport.amqp.AmqpHeader;
import org.apache.activemq.transport.amqp.AmqpWireFormat;
import org.apache.activemq.transport.amqp.AmqpWireFormat.ResetListener;
import org.junit.Test;
public class AmqpWireFormatTest {
private final AmqpWireFormat wireFormat = new AmqpWireFormat();
@Test
public void testWhenSaslNotAllowedNonSaslHeaderIsInvliad() {
wireFormat.setAllowNonSaslConnections(false);
AmqpHeader nonSaslHeader = new AmqpHeader();
assertFalse(wireFormat.isHeaderValid(nonSaslHeader));
AmqpHeader saslHeader = new AmqpHeader();
saslHeader.setProtocolId(3);
assertTrue(wireFormat.isHeaderValid(saslHeader));
}
@Test
public void testWhenSaslAllowedNonSaslHeaderIsValid() {
wireFormat.setAllowNonSaslConnections(true);
AmqpHeader nonSaslHeader = new AmqpHeader();
assertTrue(wireFormat.isHeaderValid(nonSaslHeader));
AmqpHeader saslHeader = new AmqpHeader();
saslHeader.setProtocolId(3);
assertTrue(wireFormat.isHeaderValid(saslHeader));
}
@Test
public void testMagicResetListener() throws Exception {
final AtomicBoolean reset = new AtomicBoolean();
wireFormat.setProtocolResetListener(new ResetListener() {
@Override
public void onProtocolReset() {
reset.set(true);
}
});
wireFormat.resetMagicRead();
assertTrue(reset.get());
}
}

View File

@ -0,0 +1,258 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp.protocol;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.UnknownHostException;
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocketFactory;
import org.apache.activemq.transport.amqp.AmqpHeader;
import org.apache.activemq.transport.amqp.AmqpTestSupport;
import org.apache.activemq.util.Wait;
import org.fusesource.hawtbuf.Buffer;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Test that the Broker handles connections from older clients or
* non-AMQP client correctly by returning an AMQP header prior to
* closing the socket.
*/
public class UnsupportedClientTest extends AmqpTestSupport {
private static final Logger LOG = LoggerFactory.getLogger(UnsupportedClientTest.class);
@Override
@Before
public void setUp() throws Exception {
System.setProperty("javax.net.ssl.trustStore", "src/test/resources/client.keystore");
System.setProperty("javax.net.ssl.trustStorePassword", "password");
System.setProperty("javax.net.ssl.trustStoreType", "jks");
System.setProperty("javax.net.ssl.keyStore", "src/test/resources/server.keystore");
System.setProperty("javax.net.ssl.keyStorePassword", "password");
System.setProperty("javax.net.ssl.keyStoreType", "jks");
super.setUp();
}
@Test(timeout = 60000)
public void testOlderProtocolIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setMajor(0);
header.setMinor(9);
header.setRevision(1);
// Test TCP
doTestInvalidHeaderProcessing(port, header, false);
// Test SSL
doTestInvalidHeaderProcessing(sslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(nioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(nioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testNewerMajorIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setMajor(2);
header.setMinor(0);
header.setRevision(0);
// Test TCP
doTestInvalidHeaderProcessing(port, header, false);
// Test SSL
doTestInvalidHeaderProcessing(sslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(nioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(nioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testNewerMinorIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setMajor(1);
header.setMinor(1);
header.setRevision(0);
// Test TCP
doTestInvalidHeaderProcessing(port, header, false);
// Test SSL
doTestInvalidHeaderProcessing(sslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(nioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(nioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testNewerRevisionIsRejected() throws Exception {
AmqpHeader header = new AmqpHeader();
header.setMajor(1);
header.setMinor(0);
header.setRevision(1);
// Test TCP
doTestInvalidHeaderProcessing(port, header, false);
// Test SSL
doTestInvalidHeaderProcessing(sslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(nioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(nioPlusSslPort, header, true);
}
@Test(timeout = 60000)
public void testInvalidProtocolHeader() throws Exception {
AmqpHeader header = new AmqpHeader(new Buffer(new byte[]{'S', 'T', 'O', 'M', 'P', 0, 0, 0}), false);
// Test TCP
doTestInvalidHeaderProcessing(port, header, false);
// Test SSL
doTestInvalidHeaderProcessing(sslPort, header, true);
// Test NIO
doTestInvalidHeaderProcessing(nioPort, header, false);
// Test NIO+SSL
doTestInvalidHeaderProcessing(nioPlusSslPort, header, true);
}
protected void doTestInvalidHeaderProcessing(int port, final AmqpHeader header, boolean ssl) throws Exception {
final ClientConnection connection = createClientConnection(ssl);
connection.open("localhost", port);
connection.send(header);
AmqpHeader response = connection.readAmqpHeader();
assertNotNull(response);
LOG.info("Broker responded with: {}", response);
assertTrue("Broker should have closed client connection", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
try {
connection.send(header);
return false;
} catch (Exception e) {
return true;
}
}
}));
}
private ClientConnection createClientConnection(boolean ssl) {
if (ssl) {
return new SslClientConnection();
} else {
return new ClientConnection();
}
}
private class ClientConnection {
protected static final long RECEIVE_TIMEOUT = 10000;
protected Socket clientSocket;
public void open(String host, int port) throws IOException, UnknownHostException {
clientSocket = new Socket(host, port);
clientSocket.setTcpNoDelay(true);
}
public void send(AmqpHeader header) throws Exception {
OutputStream outputStream = clientSocket.getOutputStream();
header.getBuffer().writeTo(outputStream);
outputStream.flush();
}
public AmqpHeader readAmqpHeader() throws Exception {
clientSocket.setSoTimeout((int)RECEIVE_TIMEOUT);
InputStream is = clientSocket.getInputStream();
byte[] header = new byte[8];
int read = is.read(header);
if (read == header.length) {
return new AmqpHeader(new Buffer(header));
} else {
return null;
}
}
}
private class SslClientConnection extends ClientConnection {
@Override
public void open(String host, int port) throws IOException, UnknownHostException {
SocketFactory factory = SSLSocketFactory.getDefault();
clientSocket = factory.createSocket(host, port);
clientSocket.setTcpNoDelay(true);
}
}
@Override
protected boolean isUseTcpConnector() {
return true;
}
@Override
protected boolean isUseSslConnector() {
return true;
}
@Override
protected boolean isUseNioConnector() {
return true;
}
@Override
protected boolean isUseNioPlusSslConnector() {
return true;
}
}