StompSocket and MQTTSocket will now return the appropriate web socket
remote address based on the HttpRequestServlet that initialized
the web socket connection.
This commit is contained in:
Christopher L. Shannon (cshannon) 2015-06-29 12:35:17 +00:00
parent bbf288b12c
commit be10b866a7
12 changed files with 217 additions and 139 deletions

View File

@ -115,6 +115,11 @@
<scope>provided</scope>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>

View File

@ -0,0 +1,20 @@
package org.apache.activemq.transport.util;
import javax.servlet.http.HttpServletRequest;
public class HttpTransportUtils {
public static String generateWsRemoteAddress(HttpServletRequest request) {
if (request == null) {
throw new IllegalArgumentException("HttpServletRequest must not be null.");
}
StringBuilder remoteAddress = new StringBuilder();
String scheme = request.getScheme();
remoteAddress.append(scheme != null && scheme.toLowerCase().equals("https") ? "wss://" : "ws://");
remoteAddress.append(request.getRemoteAddr());
remoteAddress.append(":");
remoteAddress.append(request.getRemotePort());
return remoteAddress.toString();
}
}

View File

@ -0,0 +1,76 @@
package org.apache.activemq.transport.ws;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.transport.TransportSupport;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.util.ServiceStopper;
public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware {
protected MQTTWireFormat wireFormat = new MQTTWireFormat();
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
protected MQTTProtocolConverter protocolConverter = null;
private BrokerService brokerService;
protected final String remoteAddress;
public AbstractMQTTSocket(String remoteAddress) {
super();
this.remoteAddress = remoteAddress;
}
protected boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
protected void doStart() throws Exception {
socketTransportStarted.countDown();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
protected MQTTProtocolConverter getProtocolConverter() {
if( protocolConverter == null ) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
@Override
public int getReceiveCounter() {
return 0;
}
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
}
@Override
public MQTTWireFormat getWireFormat() {
return wireFormat;
}
@Override
public String getRemoteAddress() {
return remoteAddress;
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
}

View File

@ -45,6 +45,13 @@ public abstract class AbstractStompSocket extends TransportSupport implements St
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
protected final StompInactivityMonitor stompInactivityMonitor = new StompInactivityMonitor(this, wireFormat);
protected volatile int receiveCounter;
protected final String remoteAddress;
public AbstractStompSocket(String remoteAddress) {
super();
this.remoteAddress = remoteAddress;
}
@Override
public void oneway(Object command) throws IOException {
@ -100,7 +107,7 @@ public abstract class AbstractStompSocket extends TransportSupport implements St
@Override
public String getRemoteAddress() {
return "StompSocket_" + this.hashCode();
return remoteAddress;
}
@Override

View File

@ -16,35 +16,26 @@
*/
package org.apache.activemq.transport.ws.jetty8;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import java.io.IOException;
import org.apache.activemq.command.Command;
import org.apache.activemq.transport.TransportSupport;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.eclipse.jetty.websocket.WebSocket;
import org.fusesource.mqtt.codec.DISCONNECT;
import org.fusesource.mqtt.codec.MQTTFrame;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryMessage, MQTTTransport, BrokerServiceAware {
public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinaryMessage {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
Connection outbound;
MQTTProtocolConverter protocolConverter = null;
MQTTWireFormat wireFormat = new MQTTWireFormat();
private final CountDownLatch socketTransportStarted = new CountDownLatch(1);
private BrokerService brokerService;
public MQTTSocket(String remoteAddress) {
super(remoteAddress);
}
@Override
public void onMessage(byte[] bytes, int offset, int length) {
@ -65,12 +56,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
}
}
private MQTTProtocolConverter getProtocolConverter() {
if( protocolConverter == null ) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
@Override
public void onOpen(Connection connection) {
@ -86,28 +71,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
}
}
protected void doStart() throws Exception {
socketTransportStarted.countDown();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
private boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
@Override
public int getReceiveCounter() {
return 0;
}
@Override
public String getRemoteAddress() {
return "MQTTSocket_" + this.hashCode();
}
@Override
public void oneway(Object command) throws IOException {
try {
@ -128,23 +91,4 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
outbound.sendMessage(bytes.getData(), 0, bytes.getLength());
}
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
}
@Override
public MQTTWireFormat getWireFormat() {
return wireFormat;
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
}

View File

@ -28,12 +28,16 @@ import org.slf4j.LoggerFactory;
/**
* Implements web socket and mediates between servlet and the broker
*/
class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage {
public class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage {
private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class);
private Connection outbound;
public StompSocket(String remoteAddress) {
super(remoteAddress);
}
@Override
public void handleStopped() throws IOException {
if (outbound != null && outbound.isOpen()) {

View File

@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.eclipse.jetty.websocket.WebSocket;
import org.eclipse.jetty.websocket.WebSocketServlet;
@ -54,9 +55,9 @@ public class WSServlet extends WebSocketServlet {
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) {
WebSocket socket;
if (protocol != null && protocol.startsWith("mqtt")) {
socket = new MQTTSocket();
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(request));
} else {
socket = new StompSocket();
socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(request));
}
listener.onAccept((Transport) socket);
return socket;

View File

@ -16,17 +16,13 @@
*/
package org.apache.activemq.transport.ws.jetty9;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.activemq.command.Command;
import org.apache.activemq.transport.TransportSupport;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.fusesource.mqtt.codec.DISCONNECT;
@ -34,47 +30,13 @@ import org.fusesource.mqtt.codec.MQTTFrame;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
public class MQTTSocket extends TransportSupport implements WebSocketListener, MQTTTransport, BrokerServiceAware {
public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
Session session;
MQTTProtocolConverter protocolConverter = null;
MQTTWireFormat wireFormat = new MQTTWireFormat();
private final CountDownLatch socketTransportStarted = new CountDownLatch(1);
private BrokerService brokerService;
private MQTTProtocolConverter getProtocolConverter() {
if( protocolConverter == null ) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
protected void doStart() throws Exception {
socketTransportStarted.countDown();
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
private boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
@Override
public int getReceiveCounter() {
return 0;
}
@Override
public String getRemoteAddress() {
return "MQTTSocket_" + this.hashCode();
public MQTTSocket(String remoteAddress) {
super(remoteAddress);
}
@Override
@ -97,26 +59,6 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener,
session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength()));
}
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
}
@Override
public MQTTWireFormat getWireFormat() {
return wireFormat;
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
@Override
public void onWebSocketBinary(byte[] bytes, int offset, int length) {
if (!transportStartedAtLeastOnce()) {
@ -142,7 +84,7 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener,
getProtocolConverter().onMQTTCommand(new DISCONNECT().encode());
} catch (Exception e) {
LOG.warn("Failed to close WebSocket", e);
}
}
}
@Override
@ -152,10 +94,10 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener,
@Override
public void onWebSocketError(Throwable arg0) {
}
@Override
public void onWebSocketText(String arg0) {
public void onWebSocketText(String arg0) {
}
}

View File

@ -29,12 +29,16 @@ import org.slf4j.LoggerFactory;
/**
* Implements web socket and mediates between servlet and the broker
*/
class StompSocket extends AbstractStompSocket implements WebSocketListener {
public class StompSocket extends AbstractStompSocket implements WebSocketListener {
private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class);
private Session session;
public StompSocket(String remoteAddress) {
super(remoteAddress);
}
@Override
public void sendToStomp(StompFrame command) throws IOException {
session.getRemote().sendString(command.format());

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
@ -62,10 +63,10 @@ public class WSServlet extends WebSocketServlet {
public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
WebSocketListener socket;
if (req.getSubProtocols().contains("mqtt")) {
socket = new MQTTSocket();
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
resp.setAcceptedSubProtocol("mqtt");
} else {
socket = new StompSocket();
socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
resp.setAcceptedSubProtocol("stomp");
}
listener.onAccept((Transport) socket);

View File

@ -0,0 +1,37 @@
package org.apache.activemq.transport.util;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test;
public class HttpTransportUtilsTest {
@Test
public void testGenerateWsRemoteAddress() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getScheme()).thenReturn("http");
when(request.getRemoteAddr()).thenReturn("localhost");
when(request.getRemotePort()).thenReturn(8080);
assertEquals("ws://localhost:8080", HttpTransportUtils.generateWsRemoteAddress(request));
}
@Test
public void testGenerateWssRemoteAddress() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getScheme()).thenReturn("https");
when(request.getRemoteAddr()).thenReturn("localhost");
when(request.getRemotePort()).thenReturn(8443);
assertEquals("wss://localhost:8443", HttpTransportUtils.generateWsRemoteAddress(request));
}
@Test(expected=IllegalArgumentException.class)
public void testNullHttpServleRequest() {
HttpTransportUtils.generateWsRemoteAddress(null);
}
}

View File

@ -0,0 +1,37 @@
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
public class SocketTest {
@Test
public void testStompSocketRemoteAddress() {
org.apache.activemq.transport.ws.jetty8.StompSocket stompSocketJetty8 =
new org.apache.activemq.transport.ws.jetty8.StompSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", stompSocketJetty8.getRemoteAddress());
org.apache.activemq.transport.ws.jetty9.StompSocket stompSocketJetty9 =
new org.apache.activemq.transport.ws.jetty9.StompSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", stompSocketJetty9.getRemoteAddress());
}
@Test
public void testMqttSocketRemoteAddress() {
org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty8 =
new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", mqttSocketJetty8.getRemoteAddress());
org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty9 =
new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080");
assertEquals("ws://localhost:8080", mqttSocketJetty9.getRemoteAddress());
}
}