StompSocket and MQTTSocket will now return the appropriate web socket
remote address based on the HttpRequestServlet that initialized
the web socket connection.
This commit is contained in:
Christopher L. Shannon (cshannon) 2015-06-29 12:35:17 +00:00
parent bbf288b12c
commit be10b866a7
12 changed files with 217 additions and 139 deletions

View File

@ -115,6 +115,11 @@
<scope>provided</scope> <scope>provided</scope>
<optional>true</optional> <optional>true</optional>
</dependency> </dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>
<plugins> <plugins>

View File

@ -0,0 +1,20 @@
package org.apache.activemq.transport.util;
import javax.servlet.http.HttpServletRequest;
public class HttpTransportUtils {
public static String generateWsRemoteAddress(HttpServletRequest request) {
if (request == null) {
throw new IllegalArgumentException("HttpServletRequest must not be null.");
}
StringBuilder remoteAddress = new StringBuilder();
String scheme = request.getScheme();
remoteAddress.append(scheme != null && scheme.toLowerCase().equals("https") ? "wss://" : "ws://");
remoteAddress.append(request.getRemoteAddr());
remoteAddress.append(":");
remoteAddress.append(request.getRemotePort());
return remoteAddress.toString();
}
}

View File

@ -0,0 +1,76 @@
package org.apache.activemq.transport.ws;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.transport.TransportSupport;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.util.ServiceStopper;
public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware {
protected MQTTWireFormat wireFormat = new MQTTWireFormat();
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
protected MQTTProtocolConverter protocolConverter = null;
private BrokerService brokerService;
protected final String remoteAddress;
public AbstractMQTTSocket(String remoteAddress) {
super();
this.remoteAddress = remoteAddress;
}
protected boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
protected void doStart() throws Exception {
socketTransportStarted.countDown();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
protected MQTTProtocolConverter getProtocolConverter() {
if( protocolConverter == null ) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
@Override
public int getReceiveCounter() {
return 0;
}
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
}
@Override
public MQTTWireFormat getWireFormat() {
return wireFormat;
}
@Override
public String getRemoteAddress() {
return remoteAddress;
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
}

View File

@ -45,6 +45,13 @@ public abstract class AbstractStompSocket extends TransportSupport implements St
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1); protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
protected final StompInactivityMonitor stompInactivityMonitor = new StompInactivityMonitor(this, wireFormat); protected final StompInactivityMonitor stompInactivityMonitor = new StompInactivityMonitor(this, wireFormat);
protected volatile int receiveCounter; protected volatile int receiveCounter;
protected final String remoteAddress;
public AbstractStompSocket(String remoteAddress) {
super();
this.remoteAddress = remoteAddress;
}
@Override @Override
public void oneway(Object command) throws IOException { public void oneway(Object command) throws IOException {
@ -100,7 +107,7 @@ public abstract class AbstractStompSocket extends TransportSupport implements St
@Override @Override
public String getRemoteAddress() { public String getRemoteAddress() {
return "StompSocket_" + this.hashCode(); return remoteAddress;
} }
@Override @Override

View File

@ -16,35 +16,26 @@
*/ */
package org.apache.activemq.transport.ws.jetty8; package org.apache.activemq.transport.ws.jetty8;
import org.apache.activemq.broker.BrokerService; import java.io.IOException;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.command.Command; import org.apache.activemq.command.Command;
import org.apache.activemq.transport.TransportSupport; import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.eclipse.jetty.websocket.WebSocket; import org.eclipse.jetty.websocket.WebSocket;
import org.fusesource.mqtt.codec.DISCONNECT; import org.fusesource.mqtt.codec.DISCONNECT;
import org.fusesource.mqtt.codec.MQTTFrame; import org.fusesource.mqtt.codec.MQTTFrame;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinaryMessage {
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryMessage, MQTTTransport, BrokerServiceAware {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
Connection outbound; Connection outbound;
MQTTProtocolConverter protocolConverter = null;
MQTTWireFormat wireFormat = new MQTTWireFormat(); public MQTTSocket(String remoteAddress) {
private final CountDownLatch socketTransportStarted = new CountDownLatch(1); super(remoteAddress);
private BrokerService brokerService; }
@Override @Override
public void onMessage(byte[] bytes, int offset, int length) { public void onMessage(byte[] bytes, int offset, int length) {
@ -65,12 +56,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
} }
} }
private MQTTProtocolConverter getProtocolConverter() {
if( protocolConverter == null ) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
@Override @Override
public void onOpen(Connection connection) { public void onOpen(Connection connection) {
@ -86,28 +71,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
} }
} }
protected void doStart() throws Exception {
socketTransportStarted.countDown();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
private boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
@Override
public int getReceiveCounter() {
return 0;
}
@Override
public String getRemoteAddress() {
return "MQTTSocket_" + this.hashCode();
}
@Override @Override
public void oneway(Object command) throws IOException { public void oneway(Object command) throws IOException {
try { try {
@ -128,23 +91,4 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
outbound.sendMessage(bytes.getData(), 0, bytes.getLength()); outbound.sendMessage(bytes.getData(), 0, bytes.getLength());
} }
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
}
@Override
public MQTTWireFormat getWireFormat() {
return wireFormat;
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
} }

View File

@ -28,12 +28,16 @@ import org.slf4j.LoggerFactory;
/** /**
* Implements web socket and mediates between servlet and the broker * Implements web socket and mediates between servlet and the broker
*/ */
class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage { public class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage {
private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class); private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class);
private Connection outbound; private Connection outbound;
public StompSocket(String remoteAddress) {
super(remoteAddress);
}
@Override @Override
public void handleStopped() throws IOException { public void handleStopped() throws IOException {
if (outbound != null && outbound.isOpen()) { if (outbound != null && outbound.isOpen()) {

View File

@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener; import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.eclipse.jetty.websocket.WebSocket; import org.eclipse.jetty.websocket.WebSocket;
import org.eclipse.jetty.websocket.WebSocketServlet; import org.eclipse.jetty.websocket.WebSocketServlet;
@ -54,9 +55,9 @@ public class WSServlet extends WebSocketServlet {
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) { public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) {
WebSocket socket; WebSocket socket;
if (protocol != null && protocol.startsWith("mqtt")) { if (protocol != null && protocol.startsWith("mqtt")) {
socket = new MQTTSocket(); socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(request));
} else { } else {
socket = new StompSocket(); socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(request));
} }
listener.onAccept((Transport) socket); listener.onAccept((Transport) socket);
return socket; return socket;

View File

@ -16,17 +16,13 @@
*/ */
package org.apache.activemq.transport.ws.jetty9; package org.apache.activemq.transport.ws.jetty9;
import org.apache.activemq.broker.BrokerService; import java.io.IOException;
import org.apache.activemq.broker.BrokerServiceAware; import java.nio.ByteBuffer;
import org.apache.activemq.command.Command; import org.apache.activemq.command.Command;
import org.apache.activemq.transport.TransportSupport; import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.fusesource.mqtt.codec.DISCONNECT; import org.fusesource.mqtt.codec.DISCONNECT;
@ -34,47 +30,13 @@ import org.fusesource.mqtt.codec.MQTTFrame;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener {
import java.nio.ByteBuffer;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
public class MQTTSocket extends TransportSupport implements WebSocketListener, MQTTTransport, BrokerServiceAware {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
Session session; Session session;
MQTTProtocolConverter protocolConverter = null;
MQTTWireFormat wireFormat = new MQTTWireFormat();
private final CountDownLatch socketTransportStarted = new CountDownLatch(1);
private BrokerService brokerService;
private MQTTProtocolConverter getProtocolConverter() { public MQTTSocket(String remoteAddress) {
if( protocolConverter == null ) { super(remoteAddress);
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
protected void doStart() throws Exception {
socketTransportStarted.countDown();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
private boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
@Override
public int getReceiveCounter() {
return 0;
}
@Override
public String getRemoteAddress() {
return "MQTTSocket_" + this.hashCode();
} }
@Override @Override
@ -97,26 +59,6 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener,
session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength())); session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength()));
} }
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
}
@Override
public MQTTWireFormat getWireFormat() {
return wireFormat;
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
@Override @Override
public void onWebSocketBinary(byte[] bytes, int offset, int length) { public void onWebSocketBinary(byte[] bytes, int offset, int length) {
if (!transportStartedAtLeastOnce()) { if (!transportStartedAtLeastOnce()) {

View File

@ -29,12 +29,16 @@ import org.slf4j.LoggerFactory;
/** /**
* Implements web socket and mediates between servlet and the broker * Implements web socket and mediates between servlet and the broker
*/ */
class StompSocket extends AbstractStompSocket implements WebSocketListener { public class StompSocket extends AbstractStompSocket implements WebSocketListener {
private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class); private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class);
private Session session; private Session session;
public StompSocket(String remoteAddress) {
super(remoteAddress);
}
@Override @Override
public void sendToStomp(StompFrame command) throws IOException { public void sendToStomp(StompFrame command) throws IOException {
session.getRemote().sendString(command.format()); session.getRemote().sendString(command.format());

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener; import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
@ -62,10 +63,10 @@ public class WSServlet extends WebSocketServlet {
public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) { public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
WebSocketListener socket; WebSocketListener socket;
if (req.getSubProtocols().contains("mqtt")) { if (req.getSubProtocols().contains("mqtt")) {
socket = new MQTTSocket(); socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
resp.setAcceptedSubProtocol("mqtt"); resp.setAcceptedSubProtocol("mqtt");
} else { } else {
socket = new StompSocket(); socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
resp.setAcceptedSubProtocol("stomp"); resp.setAcceptedSubProtocol("stomp");
} }
listener.onAccept((Transport) socket); listener.onAccept((Transport) socket);

View File

@ -0,0 +1,37 @@
package org.apache.activemq.transport.util;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test;
public class HttpTransportUtilsTest {
@Test
public void testGenerateWsRemoteAddress() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getScheme()).thenReturn("http");
when(request.getRemoteAddr()).thenReturn("localhost");
when(request.getRemotePort()).thenReturn(8080);
assertEquals("ws://localhost:8080", HttpTransportUtils.generateWsRemoteAddress(request));
}
@Test
public void testGenerateWssRemoteAddress() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getScheme()).thenReturn("https");
when(request.getRemoteAddr()).thenReturn("localhost");
when(request.getRemotePort()).thenReturn(8443);
assertEquals("wss://localhost:8443", HttpTransportUtils.generateWsRemoteAddress(request));
}
@Test(expected=IllegalArgumentException.class)
public void testNullHttpServleRequest() {
HttpTransportUtils.generateWsRemoteAddress(null);
}
}

View File

@ -0,0 +1,37 @@
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
public class SocketTest {
@Test
public void testStompSocketRemoteAddress() {
org.apache.activemq.transport.ws.jetty8.StompSocket stompSocketJetty8 =
new org.apache.activemq.transport.ws.jetty8.StompSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", stompSocketJetty8.getRemoteAddress());
org.apache.activemq.transport.ws.jetty9.StompSocket stompSocketJetty9 =
new org.apache.activemq.transport.ws.jetty9.StompSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", stompSocketJetty9.getRemoteAddress());
}
@Test
public void testMqttSocketRemoteAddress() {
org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty8 =
new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", mqttSocketJetty8.getRemoteAddress());
org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty9 =
new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", mqttSocketJetty9.getRemoteAddress());
}
}