ARTEMIS-826 Fix MQTT protocol detection

This commit is contained in:
Martyn Taylor 2017-04-24 16:27:46 +01:00
parent d0219bea18
commit 1c84bd39c4
3 changed files with 59 additions and 11 deletions

View File

@ -20,6 +20,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.mqtt.MqttDecoder; import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder; import io.netty.handler.codec.mqtt.MqttEncoder;
@ -115,19 +117,43 @@ class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQTTInter
pipeline.addLast(new MQTTProtocolHandler(server, this)); pipeline.addLast(new MQTTProtocolHandler(server, this));
} }
/**
* The protocol handler passes us an 8 byte long array from the transport. We sniff these first 8 bytes to see
* if they match the first 8 bytes from MQTT Connect packet. In many other protocols the protocol name is the first
* thing sent on the wire. However, in MQTT the protocol name doesn't come until later on in the CONNECT packet.
*
* In order to fully identify MQTT protocol via protocol name, we need up to 12 bytes. However, we can use other
* information from the connect packet to infer that the MQTT protocol is being used. This is enough to identify MQTT
* and add the Netty codec in the pipeline. The Netty codec takes care of things from here.
*
* MQTT CONNECT PACKET: See MQTT 3.1.1 Spec for more info.
*
* Byte 1: Fixed Header Packet Type. 0b0001000 (16) = MQTT Connect
* Byte 2-[N]: Remaining length of the Connect Packet (encoded with 1-4 bytes).
*
* The next set of bytes represents the UTF8 encoded string MQTT (MQTT 3.1.1) or MQIsdp (MQTT 3.1)
* Byte N: UTF8 MSB must be 0
* Byte N+1: UTF8 LSB must be (4(MQTT) or 6(MQIsdp))
* Byte N+1: M (first char from the protocol name).
*
* Max no bytes used in the sequence = 8.
*/
@Override @Override
public boolean isProtocol(byte[] array) { public boolean isProtocol(byte[] array) {
boolean mqtt311 = array[4] == 77 && // M ByteBuf buf = Unpooled.wrappedBuffer(array);
array[5] == 81 && // Q
array[6] == 84 && // T
array[7] == 84; // T
// FIXME The actual protocol name is 'MQIsdp' (However we are only passed the first 4 bytes of the protocol name) if (!(buf.readByte() == 16 && validateRemainingLength(buf) && buf.readByte() == (byte) 0)) return false;
boolean mqtt31 = array[4] == 77 && // M byte b = buf.readByte();
array[5] == 81 && // Q return ((b == 4 || b == 6) && (buf.readByte() == 77));
array[6] == 73 && // I }
array[7] == 115; // s
return mqtt311 || mqtt31; private boolean validateRemainingLength(ByteBuf buffer) {
byte msb = (byte) 0b10000000;
for (byte i = 0; i < 4; i++) {
if ((buffer.readByte() & msb) != msb)
return true;
}
return false;
} }
@Override @Override

View File

@ -135,7 +135,7 @@ public class ProtocolHandler {
return; return;
} }
// Will use the first five bytes to detect a protocol. // Will use the first N bytes to detect a protocol depending on the protocol.
if (in.readableBytes() < 8) { if (in.readableBytes() < 8) {
return; return;
} }
@ -175,6 +175,7 @@ public class ProtocolHandler {
protocolToUse = ActiveMQClient.DEFAULT_CORE_PROTOCOL; protocolToUse = ActiveMQClient.DEFAULT_CORE_PROTOCOL;
} }
} }
ProtocolManager protocolManagerToUse = protocolMap.get(protocolToUse); ProtocolManager protocolManagerToUse = protocolMap.get(protocolToUse);
ConnectionCreator channelHandler = nettyAcceptor.createConnectionCreator(); ConnectionCreator channelHandler = nettyAcceptor.createConnectionCreator();
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();

View File

@ -83,6 +83,27 @@ public class MQTTTest extends MQTTTestSupport {
} }
@Test
public void testConnectWithLargePassword() throws Exception {
for (String version : Arrays.asList("3.1", "3.1.1")) {
String longString = new String(new char[65535]);
BlockingConnection connection = null;
try {
MQTT mqtt = createMQTTConnection("test-" + version, true);
mqtt.setUserName(longString);
mqtt.setPassword(longString);
mqtt.setConnectAttemptsMax(1);
mqtt.setVersion(version);
connection = mqtt.blockingConnection();
connection.connect();
assertTrue(connection.isConnected());
} finally {
if (connection != null && connection.isConnected()) connection.disconnect();
}
}
}
@Test(timeout = 60 * 1000) @Test(timeout = 60 * 1000)
public void testSendAndReceiveMQTT() throws Exception { public void testSendAndReceiveMQTT() throws Exception {
final MQTTClientProvider subscriptionProvider = getMQTTClientProvider(); final MQTTClientProvider subscriptionProvider = getMQTTClientProvider();