From 540999654a8d581e2ba7c4c160a522378b318bbf Mon Sep 17 00:00:00 2001 From: gtully Date: Wed, 21 Nov 2018 10:23:13 +0000 Subject: [PATCH] AMQ-7106 - fix pending stop support by avoiding sync through single shared status var - fix and test (cherry picked from commit 8cc0c5ad6c85381cf6bbeaf179086d451d96650e) --- .../activemq/broker/TransportConnection.java | 109 +++++----- .../transport/mqtt/MQTTInactivityMonitor.java | 2 +- ...pTransportInactiveDuringHandshakeTest.java | 203 ++++++++++++++++++ 3 files changed, 255 insertions(+), 59 deletions(-) create mode 100644 activemq-unit-tests/src/test/java/org/apache/activemq/transport/tcp/TcpTransportInactiveDuringHandshakeTest.java diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/TransportConnection.java b/activemq-broker/src/main/java/org/apache/activemq/broker/TransportConnection.java index 967e977a47..27b6072eaf 100644 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/TransportConnection.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/TransportConnection.java @@ -138,8 +138,14 @@ public class TransportConnection implements Connection, Task, CommandVisitor { private boolean blocked; private boolean connected; private boolean active; - private final AtomicBoolean starting = new AtomicBoolean(); - private final AtomicBoolean pendingStop = new AtomicBoolean(); + + // state management around pending stop + private static final int NEW = 0; + private static final int STARTING = 1; + private static final int STARTED = 2; + private static final int PENDING_STOP = 3; + private final AtomicInteger status = new AtomicInteger(NEW); + private long timeStamp; private final AtomicBoolean stopping = new AtomicBoolean(false); private final CountDownLatch stopped = new CountDownLatch(1); @@ -229,7 +235,7 @@ public class TransportConnection implements Connection, Task, CommandVisitor { } public void serviceTransportException(IOException e) { - if (!stopping.get() && !pendingStop.get()) { + if (!stopping.get() && status.get() != PENDING_STOP) { transportException.set(e); if (TRANSPORTLOG.isDebugEnabled()) { TRANSPORTLOG.debug(this + " failed: " + e, e); @@ -308,7 +314,7 @@ public class TransportConnection implements Connection, Task, CommandVisitor { } ConnectionError ce = new ConnectionError(); ce.setException(e); - if (pendingStop.get()) { + if (status.get() == PENDING_STOP) { dispatchSync(ce); } else { dispatchAsync(ce); @@ -326,7 +332,7 @@ public class TransportConnection implements Connection, Task, CommandVisitor { boolean responseRequired = command.isResponseRequired(); int commandId = command.getCommandId(); try { - if (!pendingStop.get()) { + if (status.get() != PENDING_STOP) { response = command.visit(this); } else { response = new ExceptionResponse(transportException.get()); @@ -998,7 +1004,7 @@ public class TransportConnection implements Connection, Task, CommandVisitor { @Override public boolean iterate() { try { - if (pendingStop.get() || stopping.get()) { + if (status.get() == PENDING_STOP || stopping.get()) { if (dispatchStopped.compareAndSet(false, true)) { if (transportException.get() == null) { try { @@ -1054,39 +1060,39 @@ public class TransportConnection implements Connection, Task, CommandVisitor { @Override public void start() throws Exception { - try { - synchronized (this) { - starting.set(true); - if (taskRunnerFactory != null) { - taskRunner = taskRunnerFactory.createTaskRunner(this, "ActiveMQ Connection Dispatcher: " - + getRemoteAddress()); - } else { - taskRunner = null; - } - transport.start(); - active = true; - BrokerInfo info = connector.getBrokerInfo().copy(); - if (connector.isUpdateClusterClients()) { - info.setPeerBrokerInfos(this.broker.getPeerBrokerInfos()); - } else { - info.setPeerBrokerInfos(null); - } - dispatchAsync(info); + if (status.compareAndSet(NEW, STARTING)) { + try { + synchronized (this) { + if (taskRunnerFactory != null) { + taskRunner = taskRunnerFactory.createTaskRunner(this, "ActiveMQ Connection Dispatcher: " + + getRemoteAddress()); + } else { + taskRunner = null; + } + transport.start(); + active = true; + BrokerInfo info = connector.getBrokerInfo().copy(); + if (connector.isUpdateClusterClients()) { + info.setPeerBrokerInfos(this.broker.getPeerBrokerInfos()); + } else { + info.setPeerBrokerInfos(null); + } + dispatchAsync(info); - connector.onStarted(this); - } - } catch (Exception e) { - // Force clean up on an error starting up. - pendingStop.set(true); - throw e; - } finally { - // stop() can be called from within the above block, - // but we want to be sure start() completes before - // stop() runs, so queue the stop until right now: - setStarting(false); - if (isPendingStop()) { - LOG.debug("Calling the delayed stop() after start() {}", this); - stop(); + connector.onStarted(this); + } + } catch (Exception e) { + // Force clean up on an error starting up. + status.set(PENDING_STOP); + throw e; + } finally { + // stop() can be called from within the above block, + // but we want to be sure start() completes before + // stop() runs, so queue the stop until right now: + if (!status.compareAndSet(STARTING, STARTED)) { + LOG.debug("Calling the delayed stop() after start() {}", this); + stop(); + } } } } @@ -1104,10 +1110,8 @@ public class TransportConnection implements Connection, Task, CommandVisitor { public void delayedStop(final int waitTime, final String reason, Throwable cause) { if (waitTime > 0) { - synchronized (this) { - pendingStop.set(true); - transportException.set(cause); - } + status.compareAndSet(STARTING, PENDING_STOP); + transportException.set(cause); try { stopTaskRunnerFactory.execute(new Runnable() { @Override @@ -1133,12 +1137,9 @@ public class TransportConnection implements Connection, Task, CommandVisitor { public void stopAsync() { // If we're in the middle of starting then go no further... for now. - synchronized (this) { - pendingStop.set(true); - if (starting.get()) { - LOG.debug("stopAsync() called in the middle of start(). Delaying till start completes.."); - return; - } + if (status.compareAndSet(STARTING, PENDING_STOP)) { + LOG.debug("stopAsync() called in the middle of start(). Delaying till start completes.."); + return; } if (stopping.compareAndSet(false, true)) { // Let all the connection contexts know we are shutting down @@ -1347,7 +1348,7 @@ public class TransportConnection implements Connection, Task, CommandVisitor { * @return true if the Connection is starting */ public boolean isStarting() { - return starting.get(); + return status.get() == STARTING; } @Override @@ -1360,19 +1361,11 @@ public class TransportConnection implements Connection, Task, CommandVisitor { return this.faultTolerantConnection; } - protected void setStarting(boolean starting) { - this.starting.set(starting); - } - /** * @return true if the Connection needs to stop */ public boolean isPendingStop() { - return pendingStop.get(); - } - - protected void setPendingStop(boolean pendingStop) { - this.pendingStop.set(pendingStop); + return status.get() == PENDING_STOP; } private NetworkBridgeConfiguration getNetworkConfiguration(final BrokerInfo info) throws IOException { diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java index 8c56a24c9e..b3d8fba38c 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java @@ -78,7 +78,7 @@ public class MQTTInactivityMonitor extends TransportFilter { ASYNC_TASKS.execute(new Runnable() { @Override public void run() { - onException(new InactivityIOException("Channel was inactive for too (>" + (readKeepAliveTime + readGraceTime) + ") long: " + onException(new InactivityIOException("CONNECT frame not received with in connectionTimeout (>" + connectionTimeout + "): " + next.getRemoteAddress())); } }); diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/transport/tcp/TcpTransportInactiveDuringHandshakeTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/transport/tcp/TcpTransportInactiveDuringHandshakeTest.java new file mode 100644 index 0000000000..d01511fef1 --- /dev/null +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/transport/tcp/TcpTransportInactiveDuringHandshakeTest.java @@ -0,0 +1,203 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.tcp; + +import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.broker.TransportConnector; +import org.apache.activemq.util.DefaultTestAppender; +import org.apache.activemq.util.Wait; +import org.apache.log4j.Level; +import org.apache.log4j.spi.LoggingEvent; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.*; +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.security.SecureRandom; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertTrue; + +public class TcpTransportInactiveDuringHandshakeTest { + + private static final org.slf4j.Logger LOG = LoggerFactory.getLogger(TcpTransportInactiveDuringHandshakeTest.class); + + public static final String KEYSTORE_TYPE = "jks"; + public static final String PASSWORD = "password"; + public static final String SERVER_KEYSTORE = "src/test/resources/server.keystore"; + public static final String TRUST_KEYSTORE = "src/test/resources/client.keystore"; + + static { + System.setProperty("javax.net.ssl.trustStore", TRUST_KEYSTORE); + System.setProperty("javax.net.ssl.trustStorePassword", PASSWORD); + System.setProperty("javax.net.ssl.trustStoreType", KEYSTORE_TYPE); + System.setProperty("javax.net.ssl.keyStore", SERVER_KEYSTORE); + System.setProperty("javax.net.ssl.keyStorePassword", PASSWORD); + System.setProperty("javax.net.ssl.keyStoreType", KEYSTORE_TYPE); + } + + private BrokerService brokerService; + private DefaultTestAppender appender; + CountDownLatch inactivityMonitorFired = new CountDownLatch(1); + CountDownLatch handShakeComplete = new CountDownLatch(1); + + @Before + public void before() throws Exception { + brokerService = new BrokerService(); + brokerService.setPersistent(false); + brokerService.setUseJmx(false); + + appender = new DefaultTestAppender() { + @Override + public void doAppend(LoggingEvent event) { + if (event.getLevel().equals(Level.WARN) && event.getRenderedMessage().contains("InactivityIOException")) { + inactivityMonitorFired.countDown(); + } + } + }; + org.apache.log4j.Logger rootLogger = org.apache.log4j.Logger.getRootLogger(); + rootLogger.addAppender(appender); + + } + + @After + public void after() throws Exception { + org.apache.log4j.Logger rootLogger = org.apache.log4j.Logger.getRootLogger(); + rootLogger.removeAppender(appender); + + if (brokerService != null) { + brokerService.stop(); + brokerService.waitUntilStopped(); + } + } + + @Test(timeout = 60000) + public void testInactivityMonitorThreadCompletesWhenFiringDuringStart() throws Exception { + brokerService.addConnector("mqtt+nio+ssl://localhost:0?transport.connectAttemptTimeout=1000&transport.closeAsync=false"); + brokerService.start(); + brokerService.waitUntilStarted(); + + TransportConnector transportConnector = brokerService.getTransportConnectors().get(0); + URI uri = transportConnector.getPublishableConnectURI(); + + + CountDownLatch blockHandShakeCompletion = new CountDownLatch(1); + + TrustManager[] trustManagers = new TrustManager[]{new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + LOG.info("Check Server Trusted: " + s, new Throwable("HERE")); + try { + blockHandShakeCompletion.await(20, TimeUnit.SECONDS); + } catch (InterruptedException e) { + e.printStackTrace(); + } + LOG.info("Check Server Trusted done!"); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }}; + + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagers, new SecureRandom()); + + final SSLSocket sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket("127.0.0.1", uri.getPort()); + + sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { + @Override + public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { + handShakeComplete.countDown(); + } + }); + + Executors.newCachedThreadPool().submit(new Runnable() { + @Override + public void run() { + try { + sslSocket.startHandshake(); + assertTrue("Socket connected", sslSocket.isConnected()); + } catch (IOException oops) { + oops.printStackTrace(); + } + + } + }); + + assertTrue("inactivity fired", inactivityMonitorFired.await(10, TimeUnit.SECONDS)); + + assertTrue("Found non blocked inactivity monitor thread - done its work", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + // verify no InactivityMonitor Task blocked + Thread[] threads = new Thread[20]; + int activeCount = Thread.currentThread().getThreadGroup().enumerate(threads); + for (int i = 0; i