Add a connect check in the inactivity monitor to account for opened
connections that might drop but not be spotted, in the case where the
connect frame is lost this can lead to connections that aren't fully
opened and won't be cleaned up until the broker detects the socket has
failed.  

By default the connection timer is set to 30 seconds, if no connect
frame is read by then the connection is dropped.  The broker can be
configured via the 'transport.connectAttemptTimeout' URI option, a value
<= zero disable the check.
This commit is contained in:
Timothy Bish 2015-01-05 18:53:34 -05:00
parent 6c2e2f5446
commit 4b7131ff85
6 changed files with 247 additions and 44 deletions

View File

@ -47,7 +47,6 @@ public class MQTTInactivityMonitor extends TransportFilter {
private static int CHECKER_COUNTER; private static int CHECKER_COUNTER;
private static Timer READ_CHECK_TIMER; private static Timer READ_CHECK_TIMER;
private final AtomicBoolean monitorStarted = new AtomicBoolean(false);
private final AtomicBoolean failed = new AtomicBoolean(false); private final AtomicBoolean failed = new AtomicBoolean(false);
private final AtomicBoolean inReceive = new AtomicBoolean(false); private final AtomicBoolean inReceive = new AtomicBoolean(false);
private final AtomicInteger lastReceiveCounter = new AtomicInteger(0); private final AtomicInteger lastReceiveCounter = new AtomicInteger(0);
@ -57,9 +56,34 @@ public class MQTTInactivityMonitor extends TransportFilter {
private long readGraceTime = DEFAULT_CHECK_TIME_MILLS; private long readGraceTime = DEFAULT_CHECK_TIME_MILLS;
private long readKeepAliveTime = DEFAULT_CHECK_TIME_MILLS; private long readKeepAliveTime = DEFAULT_CHECK_TIME_MILLS;
private boolean keepAliveResponseRequired;
private MQTTProtocolConverter protocolConverter; private MQTTProtocolConverter protocolConverter;
private long connectionTimeout = MQTTWireFormat.DEFAULT_CONNECTION_TIMEOUT;
private SchedulerTimerTask connectCheckerTask;
private final Runnable connectChecker = new Runnable() {
private final long startTime = System.currentTimeMillis();
@Override
public void run() {
long now = System.currentTimeMillis();
if ((now - startTime) >= connectionTimeout && connectCheckerTask != null && !ASYNC_TASKS.isTerminating()) {
if (LOG.isDebugEnabled()) {
LOG.debug("No CONNECT frame received in time for " + MQTTInactivityMonitor.this.toString() + "! Throwing InactivityIOException.");
}
ASYNC_TASKS.execute(new Runnable() {
@Override
public void run() {
onException(new InactivityIOException("Channel was inactive for too (>" + (readKeepAliveTime + readGraceTime) + ") long: "
+ next.getRemoteAddress()));
}
});
}
}
};
private final Runnable readChecker = new Runnable() { private final Runnable readChecker = new Runnable() {
long lastReceiveTime = System.currentTimeMillis(); long lastReceiveTime = System.currentTimeMillis();
@ -85,15 +109,15 @@ public class MQTTInactivityMonitor extends TransportFilter {
return; return;
} }
if ((now - lastReceiveTime) >= readKeepAliveTime + readGraceTime && monitorStarted.get() && !ASYNC_TASKS.isTerminating()) { if ((now - lastReceiveTime) >= readKeepAliveTime + readGraceTime && readCheckerTask != null && !ASYNC_TASKS.isTerminating()) {
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("No message received since last read check for " + MQTTInactivityMonitor.this.toString() + "! Throwing InactivityIOException."); LOG.debug("No message received since last read check for " + MQTTInactivityMonitor.this.toString() + "! Throwing InactivityIOException.");
} }
ASYNC_TASKS.execute(new Runnable() { ASYNC_TASKS.execute(new Runnable() {
@Override @Override
public void run() { public void run() {
onException(new InactivityIOException("Channel was inactive for too (>" + (readKeepAliveTime + readGraceTime) + ") long: " onException(new InactivityIOException("Channel was inactive for too (>" +
+ next.getRemoteAddress())); (connectionTimeout) + ") long: " + next.getRemoteAddress()));
} }
}); });
} }
@ -107,12 +131,12 @@ public class MQTTInactivityMonitor extends TransportFilter {
@Override @Override
public void start() throws Exception { public void start() throws Exception {
next.start(); next.start();
startMonitorThread();
} }
@Override @Override
public void stop() throws Exception { public void stop() throws Exception {
stopMonitorThread(); stopReadChecker();
stopConnectChecker();
next.stop(); next.stop();
} }
@ -149,7 +173,8 @@ public class MQTTInactivityMonitor extends TransportFilter {
@Override @Override
public void onException(IOException error) { public void onException(IOException error) {
if (failed.compareAndSet(false, true)) { if (failed.compareAndSet(false, true)) {
stopMonitorThread(); stopConnectChecker();
stopReadChecker();
if (protocolConverter != null) { if (protocolConverter != null) {
protocolConverter.onTransportError(); protocolConverter.onTransportError();
} }
@ -173,18 +198,6 @@ public class MQTTInactivityMonitor extends TransportFilter {
this.readKeepAliveTime = readKeepAliveTime; this.readKeepAliveTime = readKeepAliveTime;
} }
public boolean isKeepAliveResponseRequired() {
return this.keepAliveResponseRequired;
}
public void setKeepAliveResponseRequired(boolean value) {
this.keepAliveResponseRequired = value;
}
public boolean isMonitorStarted() {
return this.monitorStarted.get();
}
public void setProtocolConverter(MQTTProtocolConverter protocolConverter) { public void setProtocolConverter(MQTTProtocolConverter protocolConverter) {
this.protocolConverter = protocolConverter; this.protocolConverter = protocolConverter;
} }
@ -193,41 +206,61 @@ public class MQTTInactivityMonitor extends TransportFilter {
return protocolConverter; return protocolConverter;
} }
synchronized void startMonitorThread() { synchronized void startConnectChecker(long connectionTimeout) {
this.connectionTimeout = connectionTimeout;
if (connectionTimeout > 0 && connectCheckerTask == null) {
connectCheckerTask = new SchedulerTimerTask(connectChecker);
// Not yet configured if this isn't set yet. long connectionCheckInterval = Math.min(connectionTimeout, 1000);
if (protocolConverter == null) {
return;
}
if (monitorStarted.get()) {
return;
}
if (readKeepAliveTime > 0) {
readCheckerTask = new SchedulerTimerTask(readChecker);
}
if (readKeepAliveTime > 0) {
monitorStarted.set(true);
synchronized (AbstractInactivityMonitor.class) { synchronized (AbstractInactivityMonitor.class) {
if (CHECKER_COUNTER == 0) { if (CHECKER_COUNTER == 0) {
ASYNC_TASKS = createExecutor(); ASYNC_TASKS = createExecutor();
READ_CHECK_TIMER = new Timer("InactivityMonitor ReadCheck", true); READ_CHECK_TIMER = new Timer("InactivityMonitor ReadCheck", true);
} }
CHECKER_COUNTER++; CHECKER_COUNTER++;
if (readKeepAliveTime > 0) { READ_CHECK_TIMER.schedule(connectCheckerTask, connectionCheckInterval, connectionCheckInterval);
READ_CHECK_TIMER.schedule(readCheckerTask, readKeepAliveTime, readGraceTime); }
}
}
synchronized void startReadChecker() {
if (readKeepAliveTime > 0 && readCheckerTask == null) {
readCheckerTask = new SchedulerTimerTask(readChecker);
synchronized (AbstractInactivityMonitor.class) {
if (CHECKER_COUNTER == 0) {
ASYNC_TASKS = createExecutor();
READ_CHECK_TIMER = new Timer("InactivityMonitor ReadCheck", true);
}
CHECKER_COUNTER++;
READ_CHECK_TIMER.schedule(readCheckerTask, readKeepAliveTime, readGraceTime);
}
}
}
synchronized void stopConnectChecker() {
if (connectCheckerTask != null) {
connectCheckerTask.cancel();
connectCheckerTask = null;
synchronized (AbstractInactivityMonitor.class) {
READ_CHECK_TIMER.purge();
CHECKER_COUNTER--;
if (CHECKER_COUNTER == 0) {
READ_CHECK_TIMER.cancel();
READ_CHECK_TIMER = null;
ThreadPoolUtils.shutdown(ASYNC_TASKS);
ASYNC_TASKS = null;
} }
} }
} }
} }
synchronized void stopMonitorThread() { synchronized void stopReadChecker() {
if (monitorStarted.compareAndSet(true, false)) { if (readCheckerTask != null) {
if (readCheckerTask != null) { readCheckerTask.cancel();
readCheckerTask.cancel(); readCheckerTask = null;
}
synchronized (AbstractInactivityMonitor.class) { synchronized (AbstractInactivityMonitor.class) {
READ_CHECK_TIMER.purge(); READ_CHECK_TIMER.purge();

View File

@ -625,6 +625,9 @@ public class MQTTProtocolConverter {
return; return;
} }
// Client has sent a valid CONNECT frame, we can stop the connect checker.
monitor.stopConnectChecker();
long keepAliveMS = keepAliveSeconds * 1000; long keepAliveMS = keepAliveSeconds * 1000;
LOG.debug("MQTT Client {} requests heart beat of {} ms", getClientId(), keepAliveMS); LOG.debug("MQTT Client {} requests heart beat of {} ms", getClientId(), keepAliveMS);
@ -642,7 +645,7 @@ public class MQTTProtocolConverter {
monitor.setProtocolConverter(this); monitor.setProtocolConverter(this);
monitor.setReadKeepAliveTime(keepAliveMS); monitor.setReadKeepAliveTime(keepAliveMS);
monitor.setReadGraceTime(readGracePeriod); monitor.setReadGraceTime(readGracePeriod);
monitor.startMonitorThread(); monitor.startReadChecker();
LOG.debug("MQTT Client {} established heart beat of {} ms ({} ms + {} ms grace period)", LOG.debug("MQTT Client {} established heart beat of {} ms ({} ms + {} ms grace period)",
new Object[] { getClientId(), keepAliveMS, keepAliveMS, readGracePeriod }); new Object[] { getClientId(), keepAliveMS, keepAliveMS, readGracePeriod });

View File

@ -60,6 +60,7 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor
private MQTTInactivityMonitor monitor; private MQTTInactivityMonitor monitor;
private MQTTWireFormat wireFormat; private MQTTWireFormat wireFormat;
private final AtomicBoolean stopped = new AtomicBoolean(); private final AtomicBoolean stopped = new AtomicBoolean();
private long connectAttemptTimeout = MQTTWireFormat.DEFAULT_CONNECTION_TIMEOUT;
private boolean trace; private boolean trace;
private final Object sendLock = new Object(); private final Object sendLock = new Object();
@ -148,9 +149,17 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor
} }
} }
@Override
public void start() throws Exception {
if (monitor != null) {
monitor.startConnectChecker(getConnectAttemptTimeout());
}
super.start();
}
@Override @Override
public void stop() throws Exception { public void stop() throws Exception {
if( stopped.compareAndSet(false, true) ) { if (stopped.compareAndSet(false, true)) {
super.stop(); super.stop();
} }
} }
@ -203,6 +212,24 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor
protocolConverter.setDefaultKeepAlive(defaultHeartBeat); protocolConverter.setDefaultKeepAlive(defaultHeartBeat);
} }
/**
* @return the timeout value used to fail a connection if no CONNECT frame read.
*/
public long getConnectAttemptTimeout() {
return connectAttemptTimeout;
}
/**
* Sets the timeout value used to fail a connection if no CONNECT frame is read
* in the given interval.
*
* @param connectTimeout
* the connection frame received timeout value.
*/
public void setConnectAttemptTimeout(long connectTimeout) {
this.connectAttemptTimeout = connectTimeout;
}
public boolean getPublishDollarTopics() { public boolean getPublishDollarTopics() {
return protocolConverter != null && protocolConverter.getPublishDollarTopics(); return protocolConverter != null && protocolConverter.getPublishDollarTopics();
} }

View File

@ -36,9 +36,11 @@ import org.fusesource.mqtt.codec.MQTTFrame;
public class MQTTWireFormat implements WireFormat { public class MQTTWireFormat implements WireFormat {
static final int MAX_MESSAGE_LENGTH = 1024 * 1024 * 256; static final int MAX_MESSAGE_LENGTH = 1024 * 1024 * 256;
static final long DEFAULT_CONNECTION_TIMEOUT = 30000L;
private int version = 1; private int version = 1;
@Override
public ByteSequence marshal(Object command) throws IOException { public ByteSequence marshal(Object command) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos); DataOutputStream dos = new DataOutputStream(baos);
@ -47,12 +49,14 @@ public class MQTTWireFormat implements WireFormat {
return baos.toByteSequence(); return baos.toByteSequence();
} }
@Override
public Object unmarshal(ByteSequence packet) throws IOException { public Object unmarshal(ByteSequence packet) throws IOException {
ByteArrayInputStream stream = new ByteArrayInputStream(packet); ByteArrayInputStream stream = new ByteArrayInputStream(packet);
DataInputStream dis = new DataInputStream(stream); DataInputStream dis = new DataInputStream(stream);
return unmarshal(dis); return unmarshal(dis);
} }
@Override
public void marshal(Object command, DataOutput dataOut) throws IOException { public void marshal(Object command, DataOutput dataOut) throws IOException {
MQTTFrame frame = (MQTTFrame) command; MQTTFrame frame = (MQTTFrame) command;
dataOut.write(frame.header()); dataOut.write(frame.header());
@ -74,6 +78,7 @@ public class MQTTWireFormat implements WireFormat {
} }
} }
@Override
public Object unmarshal(DataInput dataIn) throws IOException { public Object unmarshal(DataInput dataIn) throws IOException {
byte header = dataIn.readByte(); byte header = dataIn.readByte();
@ -107,6 +112,7 @@ public class MQTTWireFormat implements WireFormat {
/** /**
* @param the version of the wire format * @param the version of the wire format
*/ */
@Override
public void setVersion(int version) { public void setVersion(int version) {
this.version = version; this.version = version;
} }
@ -114,6 +120,7 @@ public class MQTTWireFormat implements WireFormat {
/** /**
* @return the version of the wire format * @return the version of the wire format
*/ */
@Override
public int getVersion() { public int getVersion() {
return this.version; return this.version;
} }

View File

@ -0,0 +1,124 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.mqtt;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.net.Socket;
import java.util.Arrays;
import java.util.Collection;
import javax.net.ssl.SSLSocketFactory;
import org.apache.activemq.util.Wait;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Test that connection attempts that don't send a CONNECT frame will
* get cleaned up by the inactivity monitor.
*/
@RunWith(Parameterized.class)
public class MQTTConnectTest extends MQTTTestSupport {
private static final Logger LOG = LoggerFactory.getLogger(MQTTConnectTest.class);
private Socket connection;
@Parameters(name="{0}")
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
{"mqtt", false},
{"mqtt+ssl", true},
{"mqtt+nio", false},
{"mqtt+nio+ssl", true}
});
}
public MQTTConnectTest(String connectorScheme, boolean useSSL) {
super(connectorScheme, useSSL);
}
@Override
@After
public void tearDown() throws Exception {
if (connection != null) {
try {
connection.close();
} catch (Throwable e) {}
connection = null;
}
super.tearDown();
}
@Override
public String getProtocolConfig() {
return "transport.connectAttemptTimeout=2000";
}
@Test(timeout = 60 * 1000)
public void testInactivityMonitor() throws Exception {
Thread t1 = new Thread() {
@Override
public void run() {
try {
connection = createConnection();
connection.getOutputStream().write(0);
connection.getOutputStream().flush();
} catch (Exception ex) {
LOG.error("unexpected exception on connect/disconnect", ex);
exceptions.add(ex);
}
}
};
t1.start();
assertTrue("one connection", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return 1 == brokerService.getTransportConnectors().get(0).connectionCount();
}
}));
// and it should be closed due to inactivity
assertTrue("no dangling connections", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return 0 == brokerService.getTransportConnectors().get(0).connectionCount();
}
}));
assertTrue("no exceptions", exceptions.isEmpty());
}
protected Socket createConnection() throws IOException {
if (isUseSSL()) {
return SSLSocketFactory.getDefault().createSocket("localhost", port);
} else {
return new Socket("localhost", port);
}
}
}

View File

@ -212,6 +212,7 @@ public class MQTTTestSupport {
StringBuilder connectorURI = new StringBuilder(); StringBuilder connectorURI = new StringBuilder();
connectorURI.append(getProtocolScheme()); connectorURI.append(getProtocolScheme());
connectorURI.append("://0.0.0.0:").append(port); connectorURI.append("://0.0.0.0:").append(port);
String protocolConfig = getProtocolConfig();
if (protocolConfig != null && !protocolConfig.isEmpty()) { if (protocolConfig != null && !protocolConfig.isEmpty()) {
connectorURI.append("?").append(protocolConfig); connectorURI.append("?").append(protocolConfig);
} }
@ -291,6 +292,14 @@ public class MQTTTestSupport {
this.protocolScheme = scheme; this.protocolScheme = scheme;
} }
public String getProtocolConfig() {
return protocolConfig;
}
public void setProtocolConfig(String config) {
this.protocolConfig = config;
}
public boolean isUseSSL() { public boolean isUseSSL() {
return this.useSSL; return this.useSSL;
} }