diff --git a/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/AmqpInactivityMonitor.java b/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/AmqpInactivityMonitor.java index 8e6f60d8b2..e7255ea303 100644 --- a/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/AmqpInactivityMonitor.java +++ b/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/AmqpInactivityMonitor.java @@ -19,6 +19,7 @@ package org.apache.activemq.transport.amqp; import java.io.IOException; import java.util.Timer; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; @@ -58,15 +59,22 @@ public class AmqpInactivityMonitor extends TransportFilter { public void run() { long now = System.currentTimeMillis(); - if ((now - startTime) >= connectionTimeout && connectCheckerTask != null && !ASYNC_TASKS.isTerminating()) { + if ((now - startTime) >= connectionTimeout && connectCheckerTask != null && !ASYNC_TASKS.isShutdown()) { LOG.debug("No connection attempt made in time for {}! Throwing InactivityIOException.", AmqpInactivityMonitor.this.toString()); - ASYNC_TASKS.execute(new Runnable() { - @Override - public void run() { - onException(new InactivityIOException( - "Channel was inactive for too (>" + (connectionTimeout) + ") long: " + next.getRemoteAddress())); + try { + ASYNC_TASKS.execute(new Runnable() { + @Override + public void run() { + onException(new InactivityIOException( + "Channel was inactive for too (>" + (connectionTimeout) + ") long: " + next.getRemoteAddress())); + } + }); + } catch (RejectedExecutionException ex) { + if (!ASYNC_TASKS.isShutdown()) { + LOG.error("Async connection timeout task was rejected from the executor: ", ex); + throw ex; } - }); + } } } }; @@ -76,26 +84,33 @@ public class AmqpInactivityMonitor extends TransportFilter { @Override public void run() { - if (keepAliveTask != null && !ASYNC_TASKS.isTerminating() && !ASYNC_TASKS.isTerminated()) { - ASYNC_TASKS.execute(new Runnable() { - @Override - public void run() { - try { - long nextIdleUpdate = amqpTransport.keepAlive(); - if (nextIdleUpdate > 0) { - synchronized (AmqpInactivityMonitor.this) { - if (keepAliveTask != null) { - keepAliveTask = new SchedulerTimerTask(keepAlive); - KEEPALIVE_TASK_TIMER.schedule(keepAliveTask, nextIdleUpdate); + if (keepAliveTask != null && !ASYNC_TASKS.isShutdown()) { + try { + ASYNC_TASKS.execute(new Runnable() { + @Override + public void run() { + try { + long nextIdleUpdate = amqpTransport.keepAlive(); + if (nextIdleUpdate > 0) { + synchronized (AmqpInactivityMonitor.this) { + if (keepAliveTask != null) { + keepAliveTask = new SchedulerTimerTask(keepAlive); + KEEPALIVE_TASK_TIMER.schedule(keepAliveTask, nextIdleUpdate); + } } } + } catch (Exception ex) { + onException(new InactivityIOException( + "Exception while performing idle checks for connection: " + next.getRemoteAddress())); } - } catch (Exception ex) { - onException(new InactivityIOException( - "Exception while performing idle checks for connection: " + next.getRemoteAddress())); } + }); + } catch (RejectedExecutionException ex) { + if (!ASYNC_TASKS.isShutdown()) { + LOG.error("Async connection timeout task was rejected from the executor: ", ex); + throw ex; } - }); + } } } }; @@ -144,7 +159,7 @@ public class AmqpInactivityMonitor extends TransportFilter { synchronized (AbstractInactivityMonitor.class) { if (CONNECTION_CHECK_TASK_COUNTER == 0) { - if (ASYNC_TASKS == null) { + if (ASYNC_TASKS == null || ASYNC_TASKS.isShutdown()) { ASYNC_TASKS = createExecutor(); } CONNECTION_CHECK_TASK_TIMER = new Timer("AMQP InactivityMonitor State Check", true); @@ -167,7 +182,7 @@ public class AmqpInactivityMonitor extends TransportFilter { synchronized (AbstractInactivityMonitor.class) { if (KEEPALIVE_TASK_COUNTER == 0) { - if (ASYNC_TASKS == null) { + if (ASYNC_TASKS == null || ASYNC_TASKS.isShutdown()) { ASYNC_TASKS = createExecutor(); } KEEPALIVE_TASK_TIMER = new Timer("AMQP InactivityMonitor Idle Update", true); diff --git a/activemq-client/src/main/java/org/apache/activemq/transport/AbstractInactivityMonitor.java b/activemq-client/src/main/java/org/apache/activemq/transport/AbstractInactivityMonitor.java index 8fbf623e92..7cc9205b93 100644 --- a/activemq-client/src/main/java/org/apache/activemq/transport/AbstractInactivityMonitor.java +++ b/activemq-client/src/main/java/org/apache/activemq/transport/AbstractInactivityMonitor.java @@ -30,7 +30,6 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; import org.apache.activemq.command.KeepAliveInfo; import org.apache.activemq.command.WireFormatInfo; import org.apache.activemq.thread.SchedulerTimerTask; -import org.apache.activemq.util.ThreadPoolUtils; import org.apache.activemq.wireformat.WireFormat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,9 +42,10 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { private static final Logger LOG = LoggerFactory.getLogger(AbstractInactivityMonitor.class); + private static final long DEFAULT_CHECK_TIME_MILLS = 30000; + private static ThreadPoolExecutor ASYNC_TASKS; private static int CHECKER_COUNTER; - private static long DEFAULT_CHECK_TIME_MILLS = 30000; private static Timer READ_CHECK_TIMER; private static Timer WRITE_CHECK_TIMER; @@ -61,9 +61,11 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { private final ReentrantReadWriteLock sendLock = new ReentrantReadWriteLock(); + private SchedulerTimerTask connectCheckerTask; private SchedulerTimerTask writeCheckerTask; private SchedulerTimerTask readCheckerTask; + private long connectAttemptTimeout = DEFAULT_CHECK_TIME_MILLS; private long readCheckTime = DEFAULT_CHECK_TIME_MILLS; private long writeCheckTime = DEFAULT_CHECK_TIME_MILLS; private long initialDelayTime = DEFAULT_CHECK_TIME_MILLS; @@ -72,6 +74,34 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { protected WireFormat wireFormat; + private final Runnable connectChecker = new Runnable() { + + private final long startTime = System.currentTimeMillis(); + + @Override + public void run() { + long now = System.currentTimeMillis(); + + if ((now - startTime) >= connectAttemptTimeout && connectCheckerTask != null && !ASYNC_TASKS.isShutdown()) { + LOG.debug("No connection attempt made in time for {}! Throwing InactivityIOException.", AbstractInactivityMonitor.this.toString()); + try { + ASYNC_TASKS.execute(new Runnable() { + @Override + public void run() { + onException(new InactivityIOException( + "Channel was inactive for too (>" + (connectAttemptTimeout) + ") long: " + next.getRemoteAddress())); + } + }); + } catch (RejectedExecutionException ex) { + if (!ASYNC_TASKS.isShutdown()) { + LOG.error("Async connection timeout task was rejected from the executor: ", ex); + throw ex; + } + } + } + } + }; + private final Runnable readChecker = new Runnable() { long lastRunTime; @@ -151,7 +181,7 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { return; } - if (!commandSent.get() && useKeepAlive && monitorStarted.get() && !ASYNC_TASKS.isTerminating() && !ASYNC_TASKS.isTerminated()) { + if (!commandSent.get() && useKeepAlive && monitorStarted.get() && !ASYNC_TASKS.isShutdown()) { LOG.trace("{} no message sent since last write check, sending a KeepAliveInfo", this); try { @@ -185,7 +215,7 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { }; }); } catch (RejectedExecutionException ex) { - if (!ASYNC_TASKS.isTerminating() && !ASYNC_TASKS.isTerminated()) { + if (!ASYNC_TASKS.isShutdown()) { LOG.error("Async write check was rejected from the executor: ", ex); throw ex; } @@ -204,7 +234,7 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { LOG.trace("A receive is in progress, skipping read check."); return; } - if (!commandReceived.get() && monitorStarted.get() && !ASYNC_TASKS.isTerminating() && !ASYNC_TASKS.isTerminated()) { + if (!commandReceived.get() && monitorStarted.get() && !ASYNC_TASKS.isShutdown()) { LOG.debug("No message received since last read check for {}. Throwing InactivityIOException.", this); try { @@ -221,7 +251,7 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { }; }); } catch (RejectedExecutionException ex) { - if (!ASYNC_TASKS.isTerminating() && !ASYNC_TASKS.isTerminated()) { + if (!ASYNC_TASKS.isShutdown()) { LOG.error("Async read check was rejected from the executor: ", ex); throw ex; } @@ -280,14 +310,14 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { // are performing a send we take a read lock. The inactivity monitor // sends its Heart-beat commands under a write lock. This means that // the MutexTransport is still responsible for synchronizing sends - this.sendLock.readLock().lock(); + sendLock.readLock().lock(); inSend.set(true); try { doOnewaySend(o); } finally { commandSent.set(true); inSend.set(false); - this.sendLock.readLock().unlock(); + sendLock.readLock().unlock(); } } @@ -319,6 +349,14 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { useKeepAlive = val; } + public long getConnectAttemptTimeout() { + return connectAttemptTimeout; + } + + public void setConnectAttemptTimeout(long connectionTimeout) { + this.connectAttemptTimeout = connectionTimeout; + } + public long getReadCheckTime() { return readCheckTime; } @@ -355,6 +393,52 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { return this.monitorStarted.get(); } + abstract protected boolean configuredOk() throws IOException; + + public synchronized void startConnectCheckTask() { + startConnectCheckTask(getConnectAttemptTimeout()); + } + + public synchronized void startConnectCheckTask(long connectionTimeout) { + if (connectionTimeout <= 0) { + return; + } + + LOG.info("Starting connection check task for: {}", this); + + this.connectAttemptTimeout = connectionTimeout; + + if (connectCheckerTask == null) { + connectCheckerTask = new SchedulerTimerTask(connectChecker); + + synchronized (AbstractInactivityMonitor.class) { + if (CHECKER_COUNTER == 0) { + if (ASYNC_TASKS == null || ASYNC_TASKS.isShutdown()) { + ASYNC_TASKS = createExecutor(); + } + if (READ_CHECK_TIMER == null) { + READ_CHECK_TIMER = new Timer("ActiveMQ InactivityMonitor ReadCheckTimer", true); + } + } + CHECKER_COUNTER++; + READ_CHECK_TIMER.schedule(connectCheckerTask, connectionTimeout); + } + } + } + + public synchronized void stopConnectCheckTask() { + if (connectCheckerTask != null) { + LOG.info("Stopping connection check task for: {}", this); + connectCheckerTask.cancel(); + connectCheckerTask = null; + + synchronized (AbstractInactivityMonitor.class) { + READ_CHECK_TIMER.purge(); + CHECKER_COUNTER--; + } + } + } + protected synchronized void startMonitorThreads() throws IOException { if (monitorStarted.get()) { return; @@ -375,11 +459,16 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { if (writeCheckTime > 0 || readCheckTime > 0) { monitorStarted.set(true); synchronized (AbstractInactivityMonitor.class) { - if (CHECKER_COUNTER == 0) { + if (ASYNC_TASKS == null || ASYNC_TASKS.isShutdown()) { ASYNC_TASKS = createExecutor(); + } + if (READ_CHECK_TIMER == null) { READ_CHECK_TIMER = new Timer("ActiveMQ InactivityMonitor ReadCheckTimer", true); + } + if (WRITE_CHECK_TIMER == null) { WRITE_CHECK_TIMER = new Timer("ActiveMQ InactivityMonitor WriteCheckTimer", true); } + CHECKER_COUNTER++; if (readCheckTime > 0) { READ_CHECK_TIMER.schedule(readCheckerTask, initialDelayTime, readCheckTime); @@ -391,9 +480,8 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { } } - abstract protected boolean configuredOk() throws IOException; - protected synchronized void stopMonitorThreads() { + stopConnectCheckTask(); if (monitorStarted.compareAndSet(true, false)) { if (readCheckerTask != null) { readCheckerTask.cancel(); @@ -401,6 +489,7 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { if (writeCheckerTask != null) { writeCheckerTask.cancel(); } + synchronized (AbstractInactivityMonitor.class) { WRITE_CHECK_TIMER.purge(); READ_CHECK_TIMER.purge(); @@ -410,7 +499,6 @@ public abstract class AbstractInactivityMonitor extends TransportFilter { READ_CHECK_TIMER.cancel(); WRITE_CHECK_TIMER = null; READ_CHECK_TIMER = null; - ThreadPoolUtils.shutdown(ASYNC_TASKS); } } } diff --git a/activemq-client/src/main/java/org/apache/activemq/transport/InactivityMonitor.java b/activemq-client/src/main/java/org/apache/activemq/transport/InactivityMonitor.java index 8b312eee24..288ec61eb6 100755 --- a/activemq-client/src/main/java/org/apache/activemq/transport/InactivityMonitor.java +++ b/activemq-client/src/main/java/org/apache/activemq/transport/InactivityMonitor.java @@ -44,7 +44,15 @@ public class InactivityMonitor extends AbstractInactivityMonitor { } } + @Override + public void start() throws Exception { + startConnectCheckTask(); + super.start(); + } + + @Override protected void processInboundWireFormatInfo(WireFormatInfo info) throws IOException { + stopConnectCheckTask(); IOException error = null; remoteWireFormatInfo = info; try { @@ -57,6 +65,7 @@ public class InactivityMonitor extends AbstractInactivityMonitor { } } + @Override protected void processOutboundWireFormatInfo(WireFormatInfo info) throws IOException{ localWireFormatInfo = info; startMonitorThreads(); diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java index 472561ac52..739e2fcf0a 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java @@ -70,6 +70,7 @@ public abstract class AbstractStompSocket extends TransportSupport implements St protected void doStart() throws Exception { socketTransportStarted.countDown(); stompInactivityMonitor.setTransportListener(getTransportListener()); + stompInactivityMonitor.startConnectCheckTask(); } //----- Abstract methods for subclasses to implement ---------------------// diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSConnectionTimeoutTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSConnectionTimeoutTest.java new file mode 100644 index 0000000000..b0ca372cb9 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSConnectionTimeoutTest.java @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Vector; +import java.util.concurrent.TimeUnit; + +import org.apache.activemq.util.Wait; +import org.eclipse.jetty.websocket.WebSocketClient; +import org.eclipse.jetty.websocket.WebSocketClientFactory; +import org.junit.Before; +import org.junit.Test; + +/** + * Test that a STOMP WS connection drops if not CONNECT or STOMP frame sent in time. + */ +public class StompWSConnectionTimeoutTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected StompWSConnection wsStompConnection; + + protected Vector exceptions = new Vector(); + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + WebSocketClientFactory clientFactory = new WebSocketClientFactory(); + clientFactory.start(); + + wsClient = clientFactory.newWebSocketClient(); + wsStompConnection = new StompWSConnection(); + + wsClient.open(wsConnectUri, wsStompConnection); + if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to STOMP WS endpoint"); + } + } + + protected String getConnectorScheme() { + return "ws"; + } + + @Test(timeout = 90000) + public void testInactivityMonitor() throws Exception { + + assertTrue("one connection", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 1 == broker.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(250))); + + // and it should be closed due to inactivity + assertTrue("no dangling connections", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 0 == broker.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(60), TimeUnit.MILLISECONDS.toMillis(500))); + + assertTrue("no exceptions", exceptions.isEmpty()); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java index 9c4abc8a02..6ab86ff8a1 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java @@ -54,6 +54,7 @@ public class WSTransportTestSupport { @Before public void setUp() throws Exception { + LOG.info("========== Starting test: {} ==========", name.getMethodName()); broker = createBroker(true); } @@ -64,6 +65,8 @@ public class WSTransportTestSupport { } catch(Exception e) { LOG.warn("Error on Broker stop."); } + + LOG.info("========== Finished test: {} ==========", name.getMethodName()); } protected String getWSConnectorURI() { diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java index 28b6926567..aaad323dee 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java @@ -19,6 +19,7 @@ package org.apache.activemq.transport.mqtt; import java.io.IOException; import java.util.Timer; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; @@ -32,7 +33,6 @@ import org.apache.activemq.transport.AbstractInactivityMonitor; import org.apache.activemq.transport.InactivityIOException; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.TransportFilter; -import org.apache.activemq.util.ThreadPoolUtils; import org.apache.activemq.wireformat.WireFormat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,17 +69,25 @@ public class MQTTInactivityMonitor extends TransportFilter { long now = System.currentTimeMillis(); - if ((now - startTime) >= connectionTimeout && connectCheckerTask != null && !ASYNC_TASKS.isTerminating()) { + if ((now - startTime) >= connectionTimeout && connectCheckerTask != null && !ASYNC_TASKS.isShutdown()) { if (LOG.isDebugEnabled()) { LOG.debug("No CONNECT frame received in time for " + MQTTInactivityMonitor.this.toString() + "! Throwing InactivityIOException."); } - ASYNC_TASKS.execute(new Runnable() { - @Override - public void run() { - onException(new InactivityIOException("Channel was inactive for too (>" + (readKeepAliveTime + readGraceTime) + ") long: " - + next.getRemoteAddress())); + + try { + ASYNC_TASKS.execute(new Runnable() { + @Override + public void run() { + onException(new InactivityIOException("Channel was inactive for too (>" + (readKeepAliveTime + readGraceTime) + ") long: " + + next.getRemoteAddress())); + } + }); + } catch (RejectedExecutionException ex) { + if (!ASYNC_TASKS.isShutdown()) { + LOG.error("Async connection timeout task was rejected from the executor: ", ex); + throw ex; } - }); + } } } }; @@ -109,17 +117,24 @@ public class MQTTInactivityMonitor extends TransportFilter { return; } - if ((now - lastReceiveTime) >= readKeepAliveTime + readGraceTime && readCheckerTask != null && !ASYNC_TASKS.isTerminating()) { + if ((now - lastReceiveTime) >= readKeepAliveTime + readGraceTime && readCheckerTask != null && !ASYNC_TASKS.isShutdown()) { if (LOG.isDebugEnabled()) { LOG.debug("No message received since last read check for " + MQTTInactivityMonitor.this.toString() + "! Throwing InactivityIOException."); } - ASYNC_TASKS.execute(new Runnable() { - @Override - public void run() { - onException(new InactivityIOException("Channel was inactive for too (>" + - (connectionTimeout) + ") long: " + next.getRemoteAddress())); + try { + ASYNC_TASKS.execute(new Runnable() { + @Override + public void run() { + onException(new InactivityIOException("Channel was inactive for too (>" + + (connectionTimeout) + ") long: " + next.getRemoteAddress())); + } + }); + } catch (RejectedExecutionException ex) { + if (!ASYNC_TASKS.isShutdown()) { + LOG.error("Async connection timeout task was rejected from the executor: ", ex); + throw ex; } - }); + } } } }; @@ -215,7 +230,9 @@ public class MQTTInactivityMonitor extends TransportFilter { synchronized (AbstractInactivityMonitor.class) { if (CHECKER_COUNTER == 0) { - ASYNC_TASKS = createExecutor(); + if (ASYNC_TASKS == null || ASYNC_TASKS.isShutdown()) { + ASYNC_TASKS = createExecutor(); + } READ_CHECK_TIMER = new Timer("InactivityMonitor ReadCheck", true); } CHECKER_COUNTER++; @@ -230,7 +247,9 @@ public class MQTTInactivityMonitor extends TransportFilter { synchronized (AbstractInactivityMonitor.class) { if (CHECKER_COUNTER == 0) { - ASYNC_TASKS = createExecutor(); + if (ASYNC_TASKS == null || ASYNC_TASKS.isShutdown()) { + ASYNC_TASKS = createExecutor(); + } READ_CHECK_TIMER = new Timer("InactivityMonitor ReadCheck", true); } CHECKER_COUNTER++; @@ -250,8 +269,6 @@ public class MQTTInactivityMonitor extends TransportFilter { if (CHECKER_COUNTER == 0) { READ_CHECK_TIMER.cancel(); READ_CHECK_TIMER = null; - ThreadPoolUtils.shutdown(ASYNC_TASKS); - ASYNC_TASKS = null; } } } @@ -268,8 +285,6 @@ public class MQTTInactivityMonitor extends TransportFilter { if (CHECKER_COUNTER == 0) { READ_CHECK_TIMER.cancel(); READ_CHECK_TIMER = null; - ThreadPoolUtils.shutdown(ASYNC_TASKS); - ASYNC_TASKS = null; } } } diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompInactivityMonitor.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompInactivityMonitor.java index 8b1bc33202..fa2c408019 100755 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompInactivityMonitor.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompInactivityMonitor.java @@ -41,7 +41,9 @@ public class StompInactivityMonitor extends AbstractInactivityMonitor { public void startMonitoring() throws IOException { this.isConfigured = true; - this.startMonitorThreads(); + + stopConnectCheckTask(); + startMonitorThreads(); } @Override diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java index 87774dbc7b..f9d780f0ba 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompTransportFilter.java @@ -39,7 +39,9 @@ import org.slf4j.LoggerFactory; * @author chirino */ public class StompTransportFilter extends TransportFilter implements StompTransport { + private static final Logger TRACE = LoggerFactory.getLogger(StompTransportFilter.class.getPackage().getName() + ".StompIO"); + private final ProtocolConverter protocolConverter; private StompInactivityMonitor monitor; private StompWireFormat wireFormat; @@ -55,6 +57,14 @@ public class StompTransportFilter extends TransportFilter implements StompTransp } } + @Override + public void start() throws Exception { + if (monitor != null) { + monitor.startConnectCheckTask(getConnectAttemptTimeout()); + } + super.start(); + } + @Override public void oneway(Object o) throws IOException { try { @@ -168,12 +178,20 @@ public class StompTransportFilter extends TransportFilter implements StompTransp public int getMaxDataLength() { return wireFormat.getMaxDataLength(); } - + public void setMaxFrameSize(int maxFrameSize) { wireFormat.setMaxFrameSize(maxFrameSize); } - + public long getMaxFrameSize() { return wireFormat.getMaxFrameSize(); } + + public long getConnectAttemptTimeout() { + return wireFormat.getConnectionAttemptTimeout(); + } + + public void setConnectAttemptTimeout(long timeout) { + wireFormat.setConnectionAttemptTimeout(timeout); + } } diff --git a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java index 25ba91b364..daa4639839 100644 --- a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java +++ b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java @@ -45,15 +45,18 @@ public class StompWireFormat implements WireFormat { private static final int MAX_HEADER_LENGTH = 1024 * 10; private static final int MAX_HEADERS = 1000; private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100; + public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE; + public static final long DEFAULT_CONNECTION_TIMEOUT = 30000; private int version = 1; private int maxDataLength = MAX_DATA_LENGTH; private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE; private String stompVersion = Stomp.DEFAULT_VERSION; - + private long connectionAttemptTimeout = DEFAULT_CONNECTION_TIMEOUT; + //The current frame size as it is unmarshalled from the stream - private AtomicLong frameSize = new AtomicLong(); + private final AtomicLong frameSize = new AtomicLong(); @Override public ByteSequence marshal(Object command) throws IOException { @@ -104,7 +107,7 @@ public class StompWireFormat implements WireFormat { public Object unmarshal(DataInput in) throws IOException { try { - + // parse action String action = parseAction(in, frameSize); @@ -131,7 +134,7 @@ public class StompWireFormat implements WireFormat { // We don't know how much to read.. data ends when we hit a 0 byte b; ByteArrayOutputStream baos = null; - while ((b = in.readByte()) != 0) { + while ((b = in.readByte()) != 0) { if (baos == null) { baos = new ByteArrayOutputStream(); } else if (baos.size() > getMaxDataLength()) { @@ -141,7 +144,7 @@ public class StompWireFormat implements WireFormat { throw new ProtocolException("The maximum frame size was exceeded", true); } } - + baos.write(b); } @@ -191,7 +194,7 @@ public class StompWireFormat implements WireFormat { protected String parseAction(DataInput in, AtomicLong frameSize) throws IOException { String action = null; - + // skip white space to next real action line while (true) { action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded"); @@ -209,11 +212,11 @@ public class StompWireFormat implements WireFormat { } protected HashMap parseHeaders(DataInput in, AtomicLong frameSize) throws IOException { - HashMap headers = new HashMap(25); + HashMap headers = new HashMap(25); while (true) { ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded"); if (line != null && line.length > 1) { - + if (headers.size() > MAX_HEADERS) { throw new ProtocolException("The maximum number of headers was exceeded", true); } @@ -257,7 +260,7 @@ public class StompWireFormat implements WireFormat { } return headers; } - + protected int parseContentLength(String contentLength, AtomicLong frameSize) throws ProtocolException { int length; try { @@ -269,7 +272,7 @@ public class StompWireFormat implements WireFormat { if (length > getMaxDataLength()) { throw new ProtocolException("The maximum data length was exceeded", true); } - + if (frameSize.addAndGet(length) > getMaxFrameSize()) { throw new ProtocolException("The maximum frame size was exceeded", true); } @@ -341,7 +344,7 @@ public class StompWireFormat implements WireFormat { return new String(decoded.toByteArray(), "UTF-8"); } - + @Override public int getVersion() { return version; @@ -375,4 +378,12 @@ public class StompWireFormat implements WireFormat { public void setMaxFrameSize(long maxFrameSize) { this.maxFrameSize = maxFrameSize; } + + public long getConnectionAttemptTimeout() { + return connectionAttemptTimeout; + } + + public void setConnectionAttemptTimeout(long connectionAttemptTimeout) { + this.connectionAttemptTimeout = connectionAttemptTimeout; + } } diff --git a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompConnectTimeoutTest.java b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompConnectTimeoutTest.java new file mode 100644 index 0000000000..69fd4deb38 --- /dev/null +++ b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompConnectTimeoutTest.java @@ -0,0 +1,172 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.stomp; + +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.net.Socket; +import java.util.Arrays; +import java.util.Collection; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLSocketFactory; + +import org.apache.activemq.util.Wait; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test that connection attempts that don't send the connect performative + * get cleaned up by the inactivity monitor. + */ +@RunWith(Parameterized.class) +public class StompConnectTimeoutTest extends StompTestSupport { + + private static final Logger LOG = LoggerFactory.getLogger(StompConnectTimeoutTest.class); + + private Socket connection; + protected String connectorScheme; + + @Parameters(name="{0}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {"stomp"}, + {"stomp+ssl"}, + {"stomp+nio"}, + {"stomp+nio+ssl"} + }); + } + + public StompConnectTimeoutTest(String connectorScheme) { + this.connectorScheme = connectorScheme; + } + + protected String getConnectorScheme() { + return connectorScheme; + } + + @Override + public void tearDown() throws Exception { + if (connection != null) { + try { + connection.close(); + } catch (Throwable e) {} + connection = null; + } + super.tearDown(); + } + + @Override + public String getAdditionalConfig() { + return "?transport.connectAttemptTimeout=1200"; + } + + @Test(timeout = 15000) + public void testInactivityMonitor() throws Exception { + + Thread t1 = new Thread() { + + @Override + public void run() { + try { + connection = createSocket(); + connection.getOutputStream().write('S'); + connection.getOutputStream().flush(); + } catch (Exception ex) { + LOG.error("unexpected exception on connect/disconnect", ex); + exceptions.add(ex); + } + } + }; + + t1.start(); + + assertTrue("one connection", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 1 == brokerService.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(250))); + + // and it should be closed due to inactivity + assertTrue("no dangling connections", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 0 == brokerService.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(500))); + + assertTrue("no exceptions", exceptions.isEmpty()); + } + + @Override + protected boolean isUseTcpConnector() { + return connectorScheme.equalsIgnoreCase("stomp"); + } + + @Override + protected boolean isUseSslConnector() { + return connectorScheme.equalsIgnoreCase("stomp+ssl"); + } + + @Override + protected boolean isUseNioConnector() { + return connectorScheme.equalsIgnoreCase("stomp+nio"); + } + + @Override + protected boolean isUseNioPlusSslConnector() { + return connectorScheme.equalsIgnoreCase("stomp+nio+ssl"); + } + + @Override + protected Socket createSocket() throws IOException { + + boolean useSSL = false; + int port = 0; + + switch (connectorScheme) { + case "stomp": + port = this.port; + break; + case "stomp+ssl": + useSSL = true; + port = this.sslPort; + break; + case "stomp+nio": + port = this.nioPort; + break; + case "stomp+nio+ssl": + useSSL = true; + port = this.nioSslPort; + break; + default: + throw new IOException("Invalid STOMP connector scheme passed to test."); + } + + if (useSSL) { + return SSLSocketFactory.getDefault().createSocket("localhost", port); + } else { + return new Socket("localhost", port); + } + } +} diff --git a/activemq-stomp/src/test/resources/log4j.properties b/activemq-stomp/src/test/resources/log4j.properties index 7cc19418fd..f7c2c7fcc3 100755 --- a/activemq-stomp/src/test/resources/log4j.properties +++ b/activemq-stomp/src/test/resources/log4j.properties @@ -20,7 +20,7 @@ # log4j.rootLogger=INFO, out, stdout -#log4j.logger.org.apache.activemq.broker.scheduler=DEBUG +log4j.logger.org.apache.activemq.transport=DEBUG #log4j.logger.org.apache.activemq.network.DemandForwardingBridgeSupport=DEBUG #log4j.logger.org.apache.activemq.transport.failover=TRACE #log4j.logger.org.apache.activemq.store.jdbc=TRACE diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/openwire/OpenWireConnectionTimeoutTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/openwire/OpenWireConnectionTimeoutTest.java new file mode 100644 index 0000000000..28e09899eb --- /dev/null +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/openwire/OpenWireConnectionTimeoutTest.java @@ -0,0 +1,255 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.openwire; + +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.IOException; +import java.net.Socket; +import java.security.SecureRandom; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.Vector; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.broker.TransportConnector; +import org.apache.activemq.spring.SpringSslContext; +import org.apache.activemq.util.Wait; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Test that connection attempts that don't send the WireFormatInfo performative + * get cleaned up by the inactivity monitor. + */ +@RunWith(Parameterized.class) +public class OpenWireConnectionTimeoutTest { + + private static final Logger LOG = LoggerFactory.getLogger(OpenWireConnectionTimeoutTest.class); + + @Rule public TestName name = new TestName(); + + private Socket connection; + protected String connectorScheme; + protected int port; + protected BrokerService brokerService; + protected Vector exceptions = new Vector(); + + @Parameters(name="{0}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {"tcp"}, + {"ssl"}, + {"nio"}, + {"nio+ssl"} + }); + } + + public OpenWireConnectionTimeoutTest(String connectorScheme) { + this.connectorScheme = connectorScheme; + } + + protected String getConnectorScheme() { + return connectorScheme; + } + + public String getTestName() { + return name.getMethodName(); + } + + @Before + public void setUp() throws Exception { + LOG.info("========== start " + getTestName() + " =========="); + + startBroker(); + } + + @After + public void tearDown() throws Exception { + if (connection != null) { + try { + connection.close(); + } catch (Throwable e) {} + connection = null; + } + + stopBroker(); + + LOG.info("========== start " + getTestName() + " =========="); + } + + public String getAdditionalConfig() { + return "?transport.connectAttemptTimeout=1200"; + } + + @Test(timeout = 90000) + public void testInactivityMonitor() throws Exception { + + Thread t1 = new Thread() { + + @Override + public void run() { + try { + connection = createConnection(); + connection.getOutputStream().write('A'); + connection.getOutputStream().flush(); + } catch (Exception ex) { + LOG.error("unexpected exception on connect/disconnect", ex); + exceptions.add(ex); + } + } + }; + + t1.start(); + + assertTrue("one connection", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 1 == brokerService.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(250))); + + // and it should be closed due to inactivity + assertTrue("no dangling connections", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 0 == brokerService.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(500))); + + assertTrue("no exceptions", exceptions.isEmpty()); + } + + protected Socket createConnection() throws IOException { + boolean useSsl = false; + + switch (connectorScheme) { + case "tcp": + case "nio": + break; + case "ssl": + case "nio+ssl": + useSsl = true;; + break; + default: + throw new IOException("Invalid OpenWire connector scheme passed to test."); + } + + if (useSsl) { + return SSLSocketFactory.getDefault().createSocket("localhost", port); + } else { + return new Socket("localhost", port); + } + } + + protected void startBroker() throws Exception { + brokerService = new BrokerService(); + brokerService.setPersistent(false); + brokerService.setSchedulerSupport(false); + brokerService.setAdvisorySupport(false); + brokerService.setUseJmx(false); + brokerService.getManagementContext().setCreateConnector(false); + + SSLContext ctx = SSLContext.getInstance("TLS"); + ctx.init(new KeyManager[0], new TrustManager[]{new DefaultTrustManager()}, new SecureRandom()); + SSLContext.setDefault(ctx); + + // Setup SSL context... + final File classesDir = new File(OpenWireConnectionTimeoutTest.class.getProtectionDomain().getCodeSource().getLocation().getFile()); + File keystore = new File(classesDir, "../../src/test/resources/server.keystore"); + final SpringSslContext sslContext = new SpringSslContext(); + sslContext.setKeyStore(keystore.getCanonicalPath()); + sslContext.setKeyStorePassword("password"); + sslContext.setTrustStore(keystore.getCanonicalPath()); + sslContext.setTrustStorePassword("password"); + sslContext.afterPropertiesSet(); + brokerService.setSslContext(sslContext); + + System.setProperty("javax.net.ssl.trustStore", keystore.getCanonicalPath()); + System.setProperty("javax.net.ssl.trustStorePassword", "password"); + System.setProperty("javax.net.ssl.trustStoreType", "jks"); + System.setProperty("javax.net.ssl.keyStore", keystore.getCanonicalPath()); + System.setProperty("javax.net.ssl.keyStorePassword", "password"); + System.setProperty("javax.net.ssl.keyStoreType", "jks"); + + TransportConnector connector = null; + + switch (connectorScheme) { + case "tcp": + connector = brokerService.addConnector("tcp://0.0.0.0:0" + getAdditionalConfig()); + break; + case "nio": + connector = brokerService.addConnector("nio://0.0.0.0:0" + getAdditionalConfig()); + break; + case "ssl": + connector = brokerService.addConnector("ssl://0.0.0.0:0" + getAdditionalConfig()); + break; + case "nio+ssl": + connector = brokerService.addConnector("nio+ssl://0.0.0.0:0" + getAdditionalConfig()); + break; + default: + throw new IOException("Invalid OpenWire connector scheme passed to test."); + } + + brokerService.start(); + brokerService.waitUntilStarted(); + + port = connector.getPublishableConnectURI().getPort(); + } + + public void stopBroker() throws Exception { + if (brokerService != null) { + brokerService.stop(); + brokerService.waitUntilStopped(); + brokerService = null; + } + } + + public class DefaultTrustManager implements X509TrustManager { + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + } +}