Fix for AMQ-5073, updated AmqpNioSslTransport.java to propery handle frames. Also fixed bugs in amqp test, as seen in AMQ-5062

This commit is contained in:
Kevin Earls 2014-02-28 10:58:15 +01:00
parent dc607bbf35
commit 2360fb8596
8 changed files with 323 additions and 37 deletions

View File

@ -160,6 +160,19 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<forkCount>1</forkCount>
<reuseForks>false</reuseForks>
<surefire.argLine>-Xmx512M -Djava.awt.headless=true</surefire.argLine>
<runOrder>alphabetical</runOrder>
<forkedProcessTimeoutInSeconds>120</forkedProcessTimeoutInSeconds>
<includes>
<include>**/*Test.*</include>
</includes>
</configuration>
</plugin>
</plugins> </plugins>
</build> </build>
</profile> </profile>

View File

@ -16,21 +16,26 @@
*/ */
package org.apache.activemq.transport.amqp; package org.apache.activemq.transport.amqp;
import org.apache.activemq.transport.nio.NIOSSLTransport;
import org.apache.activemq.wireformat.WireFormat;
import org.fusesource.hawtbuf.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.SocketFactory;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException; import java.io.IOException;
import java.net.Socket; import java.net.Socket;
import java.net.URI; import java.net.URI;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import javax.net.SocketFactory;
import org.apache.activemq.transport.nio.NIOSSLTransport;
import org.apache.activemq.wireformat.WireFormat;
import org.fusesource.hawtbuf.Buffer;
public class AmqpNioSslTransport extends NIOSSLTransport { public class AmqpNioSslTransport extends NIOSSLTransport {
private DataInputStream amqpHeaderValue = new DataInputStream(new ByteArrayInputStream(new byte[]{'A', 'M', 'Q', 'P'}));
private final ByteBuffer magic = ByteBuffer.allocate(8); public final Integer AMQP_HEADER_VALUE = amqpHeaderValue.readInt();
private static final Logger LOG = LoggerFactory.getLogger(AmqpNioSslTransport.class);
private boolean magicConsumed = false;
public AmqpNioSslTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { public AmqpNioSslTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
super(wireFormat, socketFactory, remoteLocation, localLocation); super(wireFormat, socketFactory, remoteLocation, localLocation);
@ -50,27 +55,131 @@ public class AmqpNioSslTransport extends NIOSSLTransport {
@Override @Override
protected void processCommand(ByteBuffer plain) throws Exception { protected void processCommand(ByteBuffer plain) throws Exception {
// Are we waiting for the next Command or are we building on the current one? The
// frame size is in the first 4 bytes.
if (nextFrameSize == -1) {
// We can get small packets that don't give us enough for the frame size
// so allocate enough for the initial size value and
if (plain.remaining() < 4) {
if (currentBuffer == null) {
currentBuffer = ByteBuffer.allocate(4);
}
byte[] fill = new byte[plain.remaining()]; // Go until we fill the integer sized current buffer.
plain.get(fill); while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
currentBuffer.put(plain.get());
}
ByteBuffer payload = ByteBuffer.wrap(fill); // Didn't we get enough yet to figure out next frame size.
if (currentBuffer.hasRemaining()) {
return;
} else {
currentBuffer.flip();
nextFrameSize = currentBuffer.getInt();
}
} else {
// Either we are completing a previous read of the next frame size or its
// fully contained in plain already.
if (currentBuffer != null) {
// Finish the frame size integer read and get from the current buffer.
while (currentBuffer.hasRemaining()) {
currentBuffer.put(plain.get());
}
if (magic.position() != 8) { currentBuffer.flip();
nextFrameSize = currentBuffer.getInt();
while (payload.hasRemaining() && magic.position() < 8) { } else {
magic.put(payload.get()); nextFrameSize = plain.getInt();
} }
if (!magic.hasRemaining()) {
magic.flip();
doConsume(new AmqpHeader(new Buffer(magic)));
magic.position(8);
} }
} }
if (payload.hasRemaining()) { // There are three possibilities when we get here. We could have a partial frame,
doConsume(AmqpSupport.toBuffer(payload)); // a full frame, or more than 1 frame
while (true) {
LOG.debug("Entering while loop with plain.position {} remaining {} ", plain.position(), plain.remaining());
// handle headers, which start with 'A','M','Q','P' rather than size
if (nextFrameSize == AMQP_HEADER_VALUE) {
nextFrameSize = handleAmqpHeader(plain);
if (nextFrameSize == -1) {
return;
}
}
validateFrameSize(nextFrameSize);
// now we have the data, let's reallocate and try to fill it, (currentBuffer.putInt() is called
// because we need to put back the 4 bytes we read to determine the size)
currentBuffer = ByteBuffer.allocate(nextFrameSize );
currentBuffer.putInt(nextFrameSize);
if (currentBuffer.remaining() >= plain.remaining()) {
currentBuffer.put(plain);
} else {
byte[] fill = new byte[currentBuffer.remaining()];
plain.get(fill);
currentBuffer.put(fill);
}
// Either we have enough data for a new command or we have to wait for some more. If hasRemaining is true,
// we have not filled the buffer yet, i.e. we haven't received the full frame.
if (currentBuffer.hasRemaining()) {
return;
} else {
currentBuffer.flip();
LOG.debug("Calling doConsume with position {} limit {}", currentBuffer.position(), currentBuffer.limit());
doConsume(AmqpSupport.toBuffer(currentBuffer));
// Determine if there are more frames to process
if (plain.hasRemaining()) {
if (plain.remaining() < 4) {
nextFrameSize = 4;
} else {
nextFrameSize = plain.getInt();
}
} else {
nextFrameSize = -1;
currentBuffer = null;
return;
}
}
} }
} }
private void validateFrameSize(int frameSize) throws IOException {
if (nextFrameSize > AmqpWireFormat.DEFAULT_MAX_FRAME_SIZE) {
throw new IOException("Frame size of " + nextFrameSize +
"larger than max allowed " + AmqpWireFormat.DEFAULT_MAX_FRAME_SIZE);
}
}
private int handleAmqpHeader(ByteBuffer plain) {
int nextFrameSize;
LOG.debug("Consuming AMQP_HEADER");
currentBuffer = ByteBuffer.allocate(8);
currentBuffer.putInt(AMQP_HEADER_VALUE);
while (currentBuffer.hasRemaining()) {
currentBuffer.put(plain.get());
}
currentBuffer.flip();
if (!magicConsumed) { // The first case we see is special and has to be handled differently
doConsume(new AmqpHeader(new Buffer(currentBuffer)));
magicConsumed = true;
} else {
doConsume(AmqpSupport.toBuffer(currentBuffer));
}
if (plain.hasRemaining()) {
if (plain.remaining() < 4) {
nextFrameSize = 4;
} else {
nextFrameSize = plain.getInt();
}
} else {
nextFrameSize = -1;
currentBuffer = null;
}
return nextFrameSize;
}
} }

View File

@ -20,6 +20,12 @@ import java.io.File;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.Set; import java.util.Set;
import java.util.Vector; import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.jms.Connection; import javax.jms.Connection;
import javax.jms.Destination; import javax.jms.Destination;
@ -68,7 +74,20 @@ public class AmqpTestSupport {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
exceptions.clear(); exceptions.clear();
startBroker(); if (killHungThreads("setUp")) {
LOG.warn("HUNG THREADS in setUp");
}
ExecutorService executor = Executors.newSingleThreadExecutor();
Future<Boolean> future = executor.submit(new SetUpTask());
try {
LOG.debug("SetUpTask started.");
Boolean result = future.get(60, TimeUnit.SECONDS);
} catch (TimeoutException e) {
throw new Exception("startBroker timed out");
}
executor.shutdownNow();
this.numberOfMessages = 2000; this.numberOfMessages = 2000;
} }
@ -130,16 +149,51 @@ public class AmqpTestSupport {
} }
public void stopBroker() throws Exception { public void stopBroker() throws Exception {
LOG.debug("entering AmqpTestSupport.stopBroker");
if (brokerService != null) { if (brokerService != null) {
brokerService.stop(); brokerService.stop();
brokerService.waitUntilStopped(); brokerService.waitUntilStopped();
brokerService = null; brokerService = null;
} }
LOG.debug("exiting AmqpTestSupport.stopBroker");
} }
@After @After
public void tearDown() throws Exception { public void tearDown() throws Exception {
stopBroker(); ExecutorService executor = Executors.newSingleThreadExecutor();
Future<Boolean> future = executor.submit(new TearDownTask());
try {
LOG.debug("tearDown started.");
Boolean result = future.get(60, TimeUnit.SECONDS);
} catch (TimeoutException e) {
throw new Exception("startBroker timed out");
}
executor.shutdownNow();
if (killHungThreads("tearDown")) {
LOG.warn("HUNG THREADS in setUp");
}
}
private boolean killHungThreads(String stage) throws Exception{
Thread.sleep(500);
if (Thread.activeCount() == 1) {
return false;
}
LOG.warn("Hung Thread(s) on {} entry threadCount {} ", stage, Thread.activeCount());
Thread[] threads = new Thread[Thread.activeCount()];
Thread.enumerate(threads);
for (int i=0; i < threads.length; i++) {
Thread t = threads[i];
if (!t.getName().equals("main")) {
LOG.warn("KillHungThreads: Interrupting thread {}", t.getName());
t.interrupt();
}
}
LOG.warn("Hung Thread on {} exit threadCount {} ", stage, Thread.activeCount());
return true;
} }
public void sendMessages(Connection connection, Destination destination, int count) throws Exception { public void sendMessages(Connection connection, Destination destination, int count) throws Exception {
@ -191,4 +245,29 @@ public class AmqpTestSupport {
.newProxyInstance(queueViewMBeanName, QueueViewMBean.class, true); .newProxyInstance(queueViewMBeanName, QueueViewMBean.class, true);
return proxy; return proxy;
} }
public class SetUpTask implements Callable<Boolean> {
private String testName;
@Override
public Boolean call() throws Exception {
LOG.debug("in SetUpTask.call, calling startBroker");
startBroker();
return Boolean.TRUE;
}
}
public class TearDownTask implements Callable<Boolean> {
private String testName;
@Override
public Boolean call() throws Exception {
LOG.debug("in TearDownTask.call(), calling stopBroker");
stopBroker();
return Boolean.TRUE;
}
}
} }

View File

@ -0,0 +1,33 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Test the JMS client when connected to the NIO+SSL transport.
*/
public class JMSClientNioPlusSslTest extends JMSClientSslTest {
protected static final Logger LOG = LoggerFactory.getLogger(JMSClientNioPlusSslTest.class);
@Override
protected int getBrokerPort() {
LOG.debug("JMSClientNioPlusSslTest.getBrokerPort returning nioPlusSslPort {}", nioPlusSslPort);
return nioPlusSslPort;
}
}

View File

@ -29,12 +29,12 @@ import java.io.DataInputStream;
/** /**
* Test the JMS client when connected to the NIO transport. * Test the JMS client when connected to the NIO transport.
*/ */
@Ignore
public class JMSClientNioTest extends JMSClientTest { public class JMSClientNioTest extends JMSClientTest {
protected static final Logger LOG = LoggerFactory.getLogger(JMSClientNioTest.class); protected static final Logger LOG = LoggerFactory.getLogger(JMSClientNioTest.class);
@Override @Override
protected int getBrokerPort() { protected int getBrokerPort() {
LOG.debug("JMSClientNioTest.getBrokerPort returning nioPort {}", nioPort);
return nioPort; return nioPort;
} }
} }

View File

@ -0,0 +1,54 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.amqp;
import org.junit.BeforeClass;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.jms.Connection;
import javax.jms.JMSException;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import java.security.SecureRandom;
/**
* Test the JMS client when connected to the SSL transport.
*/
public class JMSClientSslTest extends JMSClientTest {
protected static final Logger LOG = LoggerFactory.getLogger(JMSClientSslTest.class);
@BeforeClass
public static void beforeClass() throws Exception {
SSLContext ctx = SSLContext.getInstance("TLS");
ctx.init(new KeyManager[0], new TrustManager[]{new DefaultTrustManager()}, new SecureRandom());
SSLContext.setDefault(ctx);
}
@Override
protected Connection createConnection(String clientId, boolean syncPublish, boolean useSsl) throws JMSException {
LOG.debug("JMSClientSslTest.createConnection called with clientId {} syncPublish {} useSsl {}", clientId, syncPublish, useSsl);
return super.createConnection(clientId, syncPublish, true);
}
@Override
protected int getBrokerPort() {
LOG.debug("JMSClientSslTest.getBrokerPort returning sslPort {}", sslPort);
return sslPort;
}
}

View File

@ -57,7 +57,6 @@ import org.objectweb.jtests.jms.framework.TestConfig;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@Ignore
public class JMSClientTest extends AmqpTestSupport { public class JMSClientTest extends AmqpTestSupport {
protected static final Logger LOG = LoggerFactory.getLogger(JMSClientTest.class); protected static final Logger LOG = LoggerFactory.getLogger(JMSClientTest.class);
@Rule public TestName name = new TestName(); @Rule public TestName name = new TestName();
@ -353,7 +352,7 @@ public class JMSClientTest extends AmqpTestSupport {
} }
} }
@Test(timeout=30000) @Test(timeout=90000)
public void testConsumerReceiveNoWaitThrowsWhenBrokerStops() throws Exception { public void testConsumerReceiveNoWaitThrowsWhenBrokerStops() throws Exception {
Connection connection = createConnection(); Connection connection = createConnection();
@ -377,7 +376,7 @@ public class JMSClientTest extends AmqpTestSupport {
try { try {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
consumer.receiveNoWait(); consumer.receiveNoWait();
TimeUnit.SECONDS.sleep(1); TimeUnit.MILLISECONDS.sleep(1000 + (i * 100));
} }
fail("Should have thrown an IllegalStateException"); fail("Should have thrown an IllegalStateException");
} catch (Exception ex) { } catch (Exception ex) {
@ -385,7 +384,7 @@ public class JMSClientTest extends AmqpTestSupport {
} }
} }
@Test(timeout=30000) @Test(timeout=60000)
public void testConsumerReceiveTimedThrowsWhenBrokerStops() throws Exception { public void testConsumerReceiveTimedThrowsWhenBrokerStops() throws Exception {
Connection connection = createConnection(); Connection connection = createConnection();
@ -408,7 +407,7 @@ public class JMSClientTest extends AmqpTestSupport {
try { try {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
consumer.receive(1000); consumer.receive(1000 + (i * 100));
} }
fail("Should have thrown an IllegalStateException"); fail("Should have thrown an IllegalStateException");
} catch (Exception ex) { } catch (Exception ex) {
@ -753,15 +752,15 @@ public class JMSClientTest extends AmqpTestSupport {
} }
private Connection createConnection() throws JMSException { private Connection createConnection() throws JMSException {
return createConnection(name.toString(), false); return createConnection(name.toString(), false, false);
} }
private Connection createConnection(boolean syncPublish) throws JMSException { private Connection createConnection(boolean syncPublish) throws JMSException {
return createConnection(name.toString(), syncPublish); return createConnection(name.toString(), syncPublish, false);
} }
private Connection createConnection(String clientId) throws JMSException { private Connection createConnection(String clientId) throws JMSException {
return createConnection(clientId, false); return createConnection(clientId, false, false);
} }
/** /**
@ -773,11 +772,11 @@ public class JMSClientTest extends AmqpTestSupport {
return port; return port;
} }
private Connection createConnection(String clientId, boolean syncPublish) throws JMSException { protected Connection createConnection(String clientId, boolean syncPublish, boolean useSsl) throws JMSException {
int brokerPort = getBrokerPort(); int brokerPort = getBrokerPort();
LOG.debug("Creating connection on port {}", brokerPort); LOG.debug("Creating connection on port {}", brokerPort);
final ConnectionFactoryImpl factory = new ConnectionFactoryImpl("localhost", brokerPort, "admin", "password"); final ConnectionFactoryImpl factory = new ConnectionFactoryImpl("localhost", brokerPort, "admin", "password", null, useSsl);
factory.setSyncPublish(syncPublish); factory.setSyncPublish(syncPublish);
factory.setTopicPrefix("topic://"); factory.setTopicPrefix("topic://");

View File

@ -48,7 +48,6 @@ import org.objectweb.jtests.jms.conform.topic.TemporaryTopicTest;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@Ignore
@RunWith(Suite.class) @RunWith(Suite.class)
@Suite.SuiteClasses({ @Suite.SuiteClasses({
// TopicSessionTest.class, // Hangs, see https://issues.apache.org/jira/browse/PROTON-154 // TopicSessionTest.class, // Hangs, see https://issues.apache.org/jira/browse/PROTON-154