From 446ff61542f47f50c2299d8ef1cae8fe2b98a5ad Mon Sep 17 00:00:00 2001 From: Justin Bertram Date: Thu, 7 Apr 2022 16:33:29 -0500 Subject: [PATCH] ARTEMIS-3770 refactor MQTT handling of client ID It would be useful for security manager implementations to be able to alter the client ID of MQTT connections. This commit supports this functionality by moving the code which handles the client ID *ahead* of the authentication code. There it sets the client ID on the connection and thereafter any component (e.g. security managers) which needs to inspect or modify it can do so on the connection. This commit also refactors the MQTT connection class to extend the abstract connection class. This greatly simplifies the MQTT connection class and will make it easier to maintain in the future. --- .../protocol/AbstractRemotingConnection.java | 12 +- .../core/protocol/mqtt/MQTTConnection.java | 182 ++---------------- .../protocol/mqtt/MQTTConnectionManager.java | 59 +----- .../protocol/mqtt/MQTTProtocolHandler.java | 151 ++++++++++----- .../protocol/mqtt/MQTTProtocolManager.java | 2 +- .../protocol/mqtt/MQTTRoutingContext.java | 2 +- .../mqtt/MQTTSecurityManagerTest.java | 152 +++++++++++++++ 7 files changed, 287 insertions(+), 273 deletions(-) create mode 100644 tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/MQTTSecurityManagerTest.java diff --git a/artemis-core-client/src/main/java/org/apache/activemq/artemis/spi/core/protocol/AbstractRemotingConnection.java b/artemis-core-client/src/main/java/org/apache/activemq/artemis/spi/core/protocol/AbstractRemotingConnection.java index f69d1d2a82..cbd235782e 100644 --- a/artemis-core-client/src/main/java/org/apache/activemq/artemis/spi/core/protocol/AbstractRemotingConnection.java +++ b/artemis-core-client/src/main/java/org/apache/activemq/artemis/spi/core/protocol/AbstractRemotingConnection.java @@ -161,20 +161,16 @@ public abstract class AbstractRemotingConnection implements RemotingConnection { @Override public List removeCloseListeners() { - List ret = new ArrayList<>(closeListeners); - + List deletedCloseListeners = new ArrayList<>(closeListeners); closeListeners.clear(); - - return ret; + return deletedCloseListeners; } @Override public List removeFailureListeners() { - List ret = getFailureListeners(); - + List deletedFailureListeners = getFailureListeners(); failureListeners.clear(); - - return ret; + return deletedFailureListeners; } @Override diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnection.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnection.java index 16b2a7f125..8a98e69dce 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnection.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnection.java @@ -19,153 +19,43 @@ package org.apache.activemq.artemis.core.protocol.mqtt; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; -import java.util.concurrent.atomic.AtomicBoolean; import org.apache.activemq.artemis.api.core.ActiveMQBuffer; import org.apache.activemq.artemis.api.core.ActiveMQException; import org.apache.activemq.artemis.api.core.SimpleString; -import org.apache.activemq.artemis.core.remoting.CloseListener; import org.apache.activemq.artemis.core.remoting.FailureListener; -import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection; +import org.apache.activemq.artemis.spi.core.protocol.AbstractRemotingConnection; import org.apache.activemq.artemis.spi.core.remoting.Connection; -import org.apache.activemq.artemis.spi.core.remoting.ReadyListener; -import javax.security.auth.Subject; - -public class MQTTConnection implements RemotingConnection { - - private final Connection transportConnection; - - private final long creationTime; - - private AtomicBoolean dataReceived; +public class MQTTConnection extends AbstractRemotingConnection { private boolean destroyed; private boolean connected; - private String clientID; - - private final List failureListeners = new CopyOnWriteArrayList<>(); - - private final List closeListeners = new CopyOnWriteArrayList<>(); - - private Subject subject; - private int receiveMaximum = -1; private String protocolVersion; + private boolean clientIdAssignedByBroker = false; + public MQTTConnection(Connection transportConnection) throws Exception { - this.transportConnection = transportConnection; - this.creationTime = System.currentTimeMillis(); - this.dataReceived = new AtomicBoolean(); + super(transportConnection, null); this.destroyed = false; transportConnection.setProtocolConnection(this); } - - @Override - public void scheduledFlush() { - flush(); - } - - @Override - public boolean isWritable(ReadyListener callback) { - return transportConnection.isWritable(callback) && transportConnection.isOpen(); - } - - @Override - public Object getID() { - return transportConnection.getID(); - } - - @Override - public long getCreationTime() { - return creationTime; - } - - @Override - public String getRemoteAddress() { - return transportConnection.getRemoteAddress(); - } - - @Override - public void addFailureListener(FailureListener listener) { - failureListeners.add(listener); - } - - @Override - public boolean removeFailureListener(FailureListener listener) { - return failureListeners.remove(listener); - } - - @Override - public void addCloseListener(CloseListener listener) { - closeListeners.add(listener); - } - - @Override - public boolean removeCloseListener(CloseListener listener) { - return closeListeners.remove(listener); - } - - @Override - public List removeCloseListeners() { - List deletedCloseListeners = copyCloseListeners(); - closeListeners.clear(); - return deletedCloseListeners; - } - - @Override - public void setCloseListeners(List listeners) { - closeListeners.clear(); - closeListeners.addAll(listeners); - } - - @Override - public List getFailureListeners() { - return failureListeners; - } - - @Override - public List removeFailureListeners() { - List deletedFailureListeners = copyFailureListeners(); - failureListeners.clear(); - return deletedFailureListeners; - } - - @Override - public void setFailureListeners(List listeners) { - failureListeners.clear(); - failureListeners.addAll(listeners); - } - - @Override - public ActiveMQBuffer createTransportBuffer(int size) { - return transportConnection.createTransportBuffer(size); - } - @Override public void fail(ActiveMQException me) { - List copy = copyFailureListeners(); + List copy = new ArrayList<>(failureListeners); for (FailureListener listener : copy) { listener.connectionFailed(me, false); } transportConnection.close(); } - private List copyFailureListeners() { - return new ArrayList<>(failureListeners); - } - - private List copyCloseListeners() { - return new ArrayList<>(closeListeners); - } - @Override public void fail(ActiveMQException me, String scaleDownTargetNodeID) { synchronized (failureListeners) { @@ -198,11 +88,6 @@ public class MQTTConnection implements RemotingConnection { disconnect(false); } - @Override - public Connection getTransportConnection() { - return transportConnection; - } - @Override public boolean isClient() { return false; @@ -224,12 +109,7 @@ public class MQTTConnection implements RemotingConnection { } protected void dataReceived() { - dataReceived.set(true); - } - - @Override - public boolean checkDataReceived() { - return dataReceived.compareAndSet(true, false); + dataReceived = true; } @Override @@ -254,31 +134,11 @@ public class MQTTConnection implements RemotingConnection { //unsupported } - @Override - public boolean isSupportReconnect() { - return false; - } - @Override public boolean isSupportsFlowControl() { return false; } - @Override - public void setAuditSubject(Subject subject) { - this.subject = subject; - } - - @Override - public Subject getAuditSubject() { - return subject; - } - - @Override - public Subject getSubject() { - return null; - } - /** * Returns the name of the protocol for this Remoting Connection * @@ -289,26 +149,6 @@ public class MQTTConnection implements RemotingConnection { return MQTTProtocolManagerFactory.MQTT_PROTOCOL_NAME + (protocolVersion != null ? protocolVersion : ""); } - /** - * Sets the client ID associated with this connection - * - * @param cID - */ - @Override - public void setClientID(String cID) { - this.clientID = cID; - } - - /** - * Returns the Client ID associated with this connection - * - * @return - */ - @Override - public String getClientID() { - return clientID; - } - @Override public String getTransportLocalAddress() { return getTransportConnection().getLocalAddress(); @@ -325,4 +165,12 @@ public class MQTTConnection implements RemotingConnection { public void setProtocolVersion(String protocolVersion) { this.protocolVersion = protocolVersion; } + + public void setClientIdAssignedByBroker(boolean clientIdAssignedByBroker) { + this.clientIdAssignedByBroker = clientIdAssignedByBroker; + } + + public boolean isClientIdAssignedByBroker() { + return clientIdAssignedByBroker; + } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java index 91b9fb8fc7..3a68ffb0a3 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java @@ -17,15 +17,12 @@ package org.apache.activemq.artemis.core.protocol.mqtt; -import java.util.UUID; import java.util.List; import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttProperties; import io.netty.handler.codec.mqtt.MqttVersion; -import io.netty.util.CharsetUtil; -import org.apache.activemq.artemis.api.core.Pair; import org.apache.activemq.artemis.api.core.client.ActiveMQClient; import org.apache.activemq.artemis.core.server.ActiveMQServer; import org.apache.activemq.artemis.core.server.ServerSession; @@ -58,7 +55,7 @@ public class MQTTConnectionManager { session.getConnection().addFailureListener(failureListener); } - void connect(MqttConnectMessage connect, String validatedUser) throws Exception { + void connect(MqttConnectMessage connect, String validatedUser, String username, String password) throws Exception { if (session.getVersion() == MQTTVersion.MQTT_5) { session.getConnection().setProtocolVersion(Byte.toString(MqttVersion.MQTT_5.protocolLevel())); String authenticationMethod = MQTTUtil.getProperty(String.class, connect.variableHeader().properties(), AUTHENTICATION_METHOD); @@ -70,32 +67,14 @@ public class MQTTConnectionManager { } } - String password = connect.payload().passwordInBytes() == null ? null : new String( connect.payload().passwordInBytes(), CharsetUtil.UTF_8); - String username = connect.payload().userName(); - // the Netty codec uses "CleanSession" for both 3.1.1 "clean session" and 5 "clean start" which have slightly different semantics boolean cleanStart = connect.variableHeader().isCleanSession(); - Pair clientIdValidation = validateClientId(connect.payload().clientIdentifier(), cleanStart); - if (clientIdValidation == null) { - // this represents an invalid client ID for MQTT 5 clients - session.getProtocolHandler().sendConnack(MQTTReasonCodes.CLIENT_IDENTIFIER_NOT_VALID); - disconnect(true); - return; - } else if (clientIdValidation.getA() == null) { - // this represents an invalid client ID for MQTT 3.x clients - session.getProtocolHandler().sendConnack(MQTTReasonCodes.IDENTIFIER_REJECTED_3); - disconnect(true); - return; - } - String clientId = clientIdValidation.getA(); - boolean assignedClientId = clientIdValidation.getB(); - + String clientId = session.getConnection().getClientID(); boolean sessionPresent = session.getProtocolManager().getSessionStates().containsKey(clientId); MQTTSessionState sessionState = getSessionState(clientId); synchronized (sessionState) { session.setSessionState(sessionState); - session.getConnection().setClientID(clientId); sessionState.setFailed(false); ServerSessionImpl serverSession = createServerSession(username, password, validatedUser); serverSession.start(); @@ -143,7 +122,7 @@ public class MQTTConnectionManager { sessionState.setClientMaxPacketSize(MQTTUtil.getProperty(Integer.class, connect.variableHeader().properties(), MAXIMUM_PACKET_SIZE, 0)); sessionState.setClientTopicAliasMaximum(MQTTUtil.getProperty(Integer.class, connect.variableHeader().properties(), TOPIC_ALIAS_MAXIMUM)); - connackProperties = getConnackProperties(clientId, assignedClientId); + connackProperties = getConnackProperties(); } else { connackProperties = MqttProperties.NO_PROPERTIES; } @@ -155,11 +134,11 @@ public class MQTTConnectionManager { } } - private MqttProperties getConnackProperties(String clientId, boolean assignedClientId) { + private MqttProperties getConnackProperties() { MqttProperties connackProperties = new MqttProperties(); - if (assignedClientId) { - connackProperties.add(new MqttProperties.StringProperty(ASSIGNED_CLIENT_IDENTIFIER.value(), clientId)); + if (this.session.getConnection().isClientIdAssignedByBroker()) { + connackProperties.add(new MqttProperties.StringProperty(ASSIGNED_CLIENT_IDENTIFIER.value(), this.session.getConnection().getClientID())); } if (this.session.getProtocolManager().getTopicAliasMaximum() != -1) { @@ -227,30 +206,4 @@ public class MQTTConnectionManager { private synchronized MQTTSessionState getSessionState(String clientId) { return session.getProtocolManager().getSessionState(clientId); } - - private Pair validateClientId(String clientId, boolean cleanSession) { - Boolean assigned = Boolean.FALSE; - if (clientId == null || clientId.isEmpty()) { - // [MQTT-3.1.3-7] [MQTT-3.1.3-6] If client does not specify a client ID and clean session is set to 1 create it. - if (cleanSession) { - assigned = Boolean.TRUE; - clientId = UUID.randomUUID().toString(); - } else { - // [MQTT-3.1.3-8] Return ID rejected and disconnect if clean session = false and client id is null - return null; - } - } else { - MQTTConnection connection = session.getProtocolManager().addConnectedClient(clientId, session.getConnection()); - - if (connection != null) { - MQTTSession existingSession = session.getProtocolManager().getSessionState(clientId).getSession(); - if (session.getVersion() == MQTTVersion.MQTT_5) { - existingSession.getProtocolHandler().sendDisconnect(MQTTReasonCodes.SESSION_TAKEN_OVER); - } - // [MQTT-3.1.4-2] If the client ID represents a client already connected to the server then the server MUST disconnect the existing client - existingSession.getConnectionManager().disconnect(false); - } - } - return new Pair<>(clientId, assigned); - } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java index 67a3cce4c1..a5f74c1ac7 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java @@ -17,6 +17,8 @@ package org.apache.activemq.artemis.core.protocol.mqtt; +import java.util.UUID; + import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.mqtt.MqttConnectMessage; @@ -42,6 +44,7 @@ import io.netty.handler.codec.mqtt.MqttUnsubscribeMessage; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.apache.activemq.artemis.api.core.ActiveMQSecurityException; +import org.apache.activemq.artemis.api.core.Pair; import org.apache.activemq.artemis.core.protocol.mqtt.exceptions.DisconnectException; import org.apache.activemq.artemis.core.server.ActiveMQServer; import org.apache.activemq.artemis.logs.AuditLogger; @@ -224,59 +227,32 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { void handleConnect(MqttConnectMessage connect) throws Exception { session.setVersion(MQTTVersion.getVersion(connect.variableHeader().version())); + if (!checkClientVersion()) { + return; + } + + session.getConnection().setClientID(connect.payload().clientIdentifier()); + if (!validateClientID(connect.variableHeader().isCleanSession())) { + return; + } + /* * Perform authentication *before* attempting redirection because redirection may be based on the user's role. */ String password = connect.payload().passwordInBytes() == null ? null : new String(connect.payload().passwordInBytes(), CharsetUtil.UTF_8); String username = connect.payload().userName(); - String validatedUser; - try { - validatedUser = session.getServer().validateUser(username, password, session.getConnection(), session.getProtocolManager().getSecurityDomain()); - } catch (ActiveMQSecurityException e) { - if (session.getVersion() == MQTTVersion.MQTT_5) { - session.getProtocolHandler().sendConnack(MQTTReasonCodes.BAD_USER_NAME_OR_PASSWORD); - } else { - session.getProtocolHandler().sendConnack(MQTTReasonCodes.NOT_AUTHORIZED_3); - } - disconnect(true); + Pair validationData = validateUser(username, password); + if (!validationData.getA()) { return; } + MQTTConnection existingConnection = session.getProtocolManager().addConnectedClient(session.getConnection().getClientID(), session.getConnection()); + disconnectExistingSession(existingConnection); + if (connection.getTransportConnection().getRouter() == null || !protocolManager.getRoutingHandler().route(connection, session, connect)) { - /* [MQTT-3.1.2-2] Reject unsupported clients. */ - if (session.getVersion() != MQTTVersion.MQTT_3_1 && - session.getVersion() != MQTTVersion.MQTT_3_1_1 && - session.getVersion() != MQTTVersion.MQTT_5) { + calculateKeepAlive(connect); - if (session.getVersion().getVersion() <= MQTTVersion.MQTT_3_1_1.getVersion()) { - // See MQTT-3.1.2-2 at http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718030 - sendConnack(MQTTReasonCodes.UNACCEPTABLE_PROTOCOL_VERSION_3); - } else { - // See MQTT-3.1.2-2 at https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901037 - sendConnack(MQTTReasonCodes.UNSUPPORTED_PROTOCOL_VERSION); - } - - disconnect(true); - return; - } - - /* - * If the server's keep-alive has been disabled (-1) or if the client is using a lower value than the server - * then we use the client's keep-alive. - * - * We must adjust the keep-alive because MQTT communicates keep-alive values in *seconds*, but the broker uses - * *milliseconds*. Also, the connection keep-alive is effectively "one and a half times" the configured - * keep-alive value. See [MQTT-3.1.2-22]. - */ - int serverKeepAlive = session.getProtocolManager().getServerKeepAlive(); - int clientKeepAlive = connect.variableHeader().keepAliveTimeSeconds(); - if (serverKeepAlive == -1 || (clientKeepAlive <= serverKeepAlive && clientKeepAlive != 0)) { - connectionEntry.ttl = clientKeepAlive * MQTTUtil.KEEP_ALIVE_ADJUSTMENT; - } else { - session.setUsingServerKeepAlive(true); - } - - session.getConnectionManager().connect(connect, validatedUser); + session.getConnectionManager().connect(connect, validationData.getB(), username, password); } } @@ -429,4 +405,93 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { ActiveMQServer getServer() { return server; } + + /* + * If the server's keep-alive has been disabled (-1) or if the client is using a lower value than the server + * then we use the client's keep-alive. + * + * We must adjust the keep-alive because MQTT communicates keep-alive values in *seconds*, but the broker uses + * *milliseconds*. Also, the connection keep-alive is effectively "one and a half times" the configured + * keep-alive value. See [MQTT-3.1.2-22]. + */ + private void calculateKeepAlive(MqttConnectMessage connect) { + int serverKeepAlive = session.getProtocolManager().getServerKeepAlive(); + int clientKeepAlive = connect.variableHeader().keepAliveTimeSeconds(); + if (serverKeepAlive == -1 || (clientKeepAlive <= serverKeepAlive && clientKeepAlive != 0)) { + connectionEntry.ttl = clientKeepAlive * MQTTUtil.KEEP_ALIVE_ADJUSTMENT; + } else { + session.setUsingServerKeepAlive(true); + } + } + + // [MQTT-3.1.2-2] Reject unsupported clients. + private boolean checkClientVersion() { + if (session.getVersion() != MQTTVersion.MQTT_3_1 && + session.getVersion() != MQTTVersion.MQTT_3_1_1 && + session.getVersion() != MQTTVersion.MQTT_5) { + + if (session.getVersion().getVersion() <= MQTTVersion.MQTT_3_1_1.getVersion()) { + // See MQTT-3.1.2-2 at http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718030 + sendConnack(MQTTReasonCodes.UNACCEPTABLE_PROTOCOL_VERSION_3); + } else { + // See MQTT-3.1.2-2 at https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901037 + sendConnack(MQTTReasonCodes.UNSUPPORTED_PROTOCOL_VERSION); + } + + disconnect(true); + return false; + } + return true; + } + + // [MQTT-3.1.4-2] If the client ID represents a client already connected to the server then the server MUST disconnect the existing client + private void disconnectExistingSession(MQTTConnection existingConnection) { + if (existingConnection != null) { + MQTTSession existingSession = session.getProtocolManager().getSessionState(session.getConnection().getClientID()).getSession(); + if (session.getVersion() == MQTTVersion.MQTT_5) { + existingSession.getProtocolHandler().sendDisconnect(MQTTReasonCodes.SESSION_TAKEN_OVER); + } + existingSession.getConnectionManager().disconnect(false); + } + } + + private Pair validateUser(String username, String password) throws Exception { + String validatedUser = null; + Boolean result; + + try { + validatedUser = server.validateUser(username, password, session.getConnection(), session.getProtocolManager().getSecurityDomain()); + result = Boolean.TRUE; + } catch (ActiveMQSecurityException e) { + if (session.getVersion() == MQTTVersion.MQTT_5) { + session.getProtocolHandler().sendConnack(MQTTReasonCodes.BAD_USER_NAME_OR_PASSWORD); + } else { + session.getProtocolHandler().sendConnack(MQTTReasonCodes.NOT_AUTHORIZED_3); + } + disconnect(true); + result = Boolean.FALSE; + } + + return new Pair<>(result, validatedUser); + } + + private boolean validateClientID(boolean isCleanSession) { + if (session.getConnection().getClientID() == null || session.getConnection().getClientID().isEmpty()) { + // [MQTT-3.1.3-7] [MQTT-3.1.3-6] If client does not specify a client ID and clean session is set to 1 create it. + if (isCleanSession) { + session.getConnection().setClientID(UUID.randomUUID().toString()); + session.getConnection().setClientIdAssignedByBroker(true); + } else { + // [MQTT-3.1.3-8] Return ID rejected and disconnect if clean session = false and client id is null + if (session.getVersion() == MQTTVersion.MQTT_5) { + session.getProtocolHandler().sendConnack(MQTTReasonCodes.CLIENT_IDENTIFIER_NOT_VALID); + } else { + session.getProtocolHandler().sendConnack(MQTTReasonCodes.IDENTIFIER_REJECTED_3); + } + disconnect(true); + return false; + } + } + return true; + } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java index 18090f146e..c48a9802eb 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java @@ -202,7 +202,7 @@ public class MQTTProtocolManager extends AbstractProtocolManager + * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.artemis.tests.integration.mqtt; + +import javax.security.auth.Subject; +import java.util.Map; +import java.util.Set; + +import org.apache.activemq.artemis.core.protocol.mqtt.MQTTProtocolManager; +import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSessionState; +import org.apache.activemq.artemis.core.remoting.impl.AbstractAcceptor; +import org.apache.activemq.artemis.core.security.CheckType; +import org.apache.activemq.artemis.core.security.Role; +import org.apache.activemq.artemis.spi.core.protocol.ProtocolManager; +import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection; +import org.apache.activemq.artemis.spi.core.remoting.Acceptor; +import org.apache.activemq.artemis.spi.core.security.ActiveMQSecurityManager5; +import org.apache.activemq.artemis.tests.util.RandomUtil; +import org.apache.activemq.artemis.tests.util.Wait; +import org.fusesource.mqtt.client.BlockingConnection; +import org.fusesource.mqtt.client.MQTT; +import org.junit.Test; + +public class MQTTSecurityManagerTest extends MQTTTestSupport { + + private String clientID = "new-" + RandomUtil.randomString(); + + @Override + public boolean isSecurityEnabled() { + return true; + } + + @Override + public void configureBroker() throws Exception { + super.configureBroker(); + server.setSecurityManager(new ActiveMQSecurityManager5() { + @Override + public Subject authenticate(String user, + String password, + RemotingConnection remotingConnection, + String securityDomain) { + remotingConnection.setClientID(clientID); + System.out.println("Setting: " + clientID); + return new Subject(); + } + + @Override + public boolean authorize(Subject subject, Set roles, CheckType checkType, String address) { + return true; + } + + @Override + public boolean validateUser(String user, String password) { + return true; + } + + @Override + public boolean validateUserAndRole(String user, String password, Set roles, CheckType checkType) { + return true; + } + }); + server.getConfiguration().setAuthenticationCacheSize(0); + server.getConfiguration().setAuthorizationCacheSize(0); + } + + @Test(timeout = 30000) + public void testSecurityManagerModifyClientID() throws Exception { + BlockingConnection connection = null; + try { + MQTT mqtt = createMQTTConnection(RandomUtil.randomString(), true); + mqtt.setUserName(fullUser); + mqtt.setPassword(fullPass); + mqtt.setConnectAttemptsMax(1); + connection = mqtt.blockingConnection(); + connection.connect(); + BlockingConnection finalConnection = connection; + assertTrue("Should be connected", Wait.waitFor(() -> finalConnection.isConnected(), 5000, 100)); + Map sessionStates = null; + Acceptor acceptor = server.getRemotingService().getAcceptor("MQTT"); + if (acceptor instanceof AbstractAcceptor) { + ProtocolManager protocolManager = ((AbstractAcceptor) acceptor).getProtocolMap().get("MQTT"); + if (protocolManager instanceof MQTTProtocolManager) { + sessionStates = ((MQTTProtocolManager) protocolManager).getSessionStates(); + } + } + assertEquals(1, sessionStates.size()); + assertTrue(sessionStates.keySet().contains(clientID)); + for (MQTTSessionState state : sessionStates.values()) { + assertEquals(clientID, state.getClientId()); + } + } finally { + if (connection != null && connection.isConnected()) connection.disconnect(); + } + } + + @Test(timeout = 30000) + public void testSecurityManagerModifyClientIDAndStealConnection() throws Exception { + BlockingConnection connection1 = null; + BlockingConnection connection2 = null; + final String CLIENT_ID = "old-" + RandomUtil.randomString(); + try { + MQTT mqtt = createMQTTConnection(CLIENT_ID, true); + mqtt.setUserName(fullUser); + mqtt.setPassword(fullPass); + mqtt.setConnectAttemptsMax(1); + connection1 = mqtt.blockingConnection(); + connection1.connect(); + final BlockingConnection finalConnection = connection1; + assertTrue("Should be connected", Wait.waitFor(() -> finalConnection.isConnected(), 5000, 100)); + Map sessionStates = null; + Acceptor acceptor = server.getRemotingService().getAcceptor("MQTT"); + if (acceptor instanceof AbstractAcceptor) { + ProtocolManager protocolManager = ((AbstractAcceptor) acceptor).getProtocolMap().get("MQTT"); + if (protocolManager instanceof MQTTProtocolManager) { + sessionStates = ((MQTTProtocolManager) protocolManager).getSessionStates(); + } + } + assertEquals(1, sessionStates.size()); + assertTrue(sessionStates.keySet().contains(clientID)); + for (MQTTSessionState state : sessionStates.values()) { + assertEquals(clientID, state.getClientId()); + } + + connection2 = mqtt.blockingConnection(); + connection2.connect(); + final BlockingConnection finalConnection2 = connection2; + assertTrue("Should be connected", Wait.waitFor(() -> finalConnection2.isConnected(), 5000, 100)); + Wait.assertFalse(() -> finalConnection.isConnected(), 5000, 100); + assertEquals(1, sessionStates.size()); + assertTrue(sessionStates.keySet().contains(clientID)); + for (MQTTSessionState state : sessionStates.values()) { + assertEquals(clientID, state.getClientId()); + } + } finally { + if (connection1 != null && connection1.isConnected()) connection1.disconnect(); + } + } +}