mirror of https://github.com/apache/activemq.git
StompSocket and MQTTSocket will now return the appropriate web socket remote address based on the HttpRequestServlet that initialized the web socket connection.
This commit is contained in:
parent
bbf288b12c
commit
be10b866a7
|
@ -115,6 +115,11 @@
|
|||
<scope>provided</scope>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<plugins>
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package org.apache.activemq.transport.util;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
public class HttpTransportUtils {
|
||||
|
||||
public static String generateWsRemoteAddress(HttpServletRequest request) {
|
||||
if (request == null) {
|
||||
throw new IllegalArgumentException("HttpServletRequest must not be null.");
|
||||
}
|
||||
|
||||
StringBuilder remoteAddress = new StringBuilder();
|
||||
String scheme = request.getScheme();
|
||||
remoteAddress.append(scheme != null && scheme.toLowerCase().equals("https") ? "wss://" : "ws://");
|
||||
remoteAddress.append(request.getRemoteAddr());
|
||||
remoteAddress.append(":");
|
||||
remoteAddress.append(request.getRemotePort());
|
||||
return remoteAddress.toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package org.apache.activemq.transport.ws;
|
||||
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
import org.apache.activemq.broker.BrokerService;
|
||||
import org.apache.activemq.broker.BrokerServiceAware;
|
||||
import org.apache.activemq.transport.TransportSupport;
|
||||
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
|
||||
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
|
||||
import org.apache.activemq.transport.mqtt.MQTTTransport;
|
||||
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
|
||||
import org.apache.activemq.util.ServiceStopper;
|
||||
|
||||
public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware {
|
||||
|
||||
protected MQTTWireFormat wireFormat = new MQTTWireFormat();
|
||||
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
|
||||
protected MQTTProtocolConverter protocolConverter = null;
|
||||
private BrokerService brokerService;
|
||||
protected final String remoteAddress;
|
||||
|
||||
public AbstractMQTTSocket(String remoteAddress) {
|
||||
super();
|
||||
this.remoteAddress = remoteAddress;
|
||||
}
|
||||
|
||||
protected boolean transportStartedAtLeastOnce() {
|
||||
return socketTransportStarted.getCount() == 0;
|
||||
}
|
||||
|
||||
protected void doStart() throws Exception {
|
||||
socketTransportStarted.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doStop(ServiceStopper stopper) throws Exception {
|
||||
}
|
||||
|
||||
protected MQTTProtocolConverter getProtocolConverter() {
|
||||
if( protocolConverter == null ) {
|
||||
protocolConverter = new MQTTProtocolConverter(this, brokerService);
|
||||
}
|
||||
return protocolConverter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getReceiveCounter() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public X509Certificate[] getPeerCertificates() {
|
||||
return new X509Certificate[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQTTInactivityMonitor getInactivityMonitor() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQTTWireFormat getWireFormat() {
|
||||
return wireFormat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getRemoteAddress() {
|
||||
return remoteAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBrokerService(BrokerService brokerService) {
|
||||
this.brokerService = brokerService;
|
||||
}
|
||||
}
|
|
@ -45,6 +45,13 @@ public abstract class AbstractStompSocket extends TransportSupport implements St
|
|||
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
|
||||
protected final StompInactivityMonitor stompInactivityMonitor = new StompInactivityMonitor(this, wireFormat);
|
||||
protected volatile int receiveCounter;
|
||||
protected final String remoteAddress;
|
||||
|
||||
|
||||
public AbstractStompSocket(String remoteAddress) {
|
||||
super();
|
||||
this.remoteAddress = remoteAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void oneway(Object command) throws IOException {
|
||||
|
@ -100,7 +107,7 @@ public abstract class AbstractStompSocket extends TransportSupport implements St
|
|||
|
||||
@Override
|
||||
public String getRemoteAddress() {
|
||||
return "StompSocket_" + this.hashCode();
|
||||
return remoteAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,35 +16,26 @@
|
|||
*/
|
||||
package org.apache.activemq.transport.ws.jetty8;
|
||||
|
||||
import org.apache.activemq.broker.BrokerService;
|
||||
import org.apache.activemq.broker.BrokerServiceAware;
|
||||
import java.io.IOException;
|
||||
|
||||
import org.apache.activemq.command.Command;
|
||||
import org.apache.activemq.transport.TransportSupport;
|
||||
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
|
||||
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
|
||||
import org.apache.activemq.transport.mqtt.MQTTTransport;
|
||||
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
|
||||
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
|
||||
import org.apache.activemq.util.ByteSequence;
|
||||
import org.apache.activemq.util.IOExceptionSupport;
|
||||
import org.apache.activemq.util.ServiceStopper;
|
||||
import org.eclipse.jetty.websocket.WebSocket;
|
||||
import org.fusesource.mqtt.codec.DISCONNECT;
|
||||
import org.fusesource.mqtt.codec.MQTTFrame;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryMessage, MQTTTransport, BrokerServiceAware {
|
||||
public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinaryMessage {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
|
||||
Connection outbound;
|
||||
MQTTProtocolConverter protocolConverter = null;
|
||||
MQTTWireFormat wireFormat = new MQTTWireFormat();
|
||||
private final CountDownLatch socketTransportStarted = new CountDownLatch(1);
|
||||
private BrokerService brokerService;
|
||||
|
||||
public MQTTSocket(String remoteAddress) {
|
||||
super(remoteAddress);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(byte[] bytes, int offset, int length) {
|
||||
|
@ -65,12 +56,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
|
|||
}
|
||||
}
|
||||
|
||||
private MQTTProtocolConverter getProtocolConverter() {
|
||||
if( protocolConverter == null ) {
|
||||
protocolConverter = new MQTTProtocolConverter(this, brokerService);
|
||||
}
|
||||
return protocolConverter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(Connection connection) {
|
||||
|
@ -86,28 +71,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
|
|||
}
|
||||
}
|
||||
|
||||
protected void doStart() throws Exception {
|
||||
socketTransportStarted.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doStop(ServiceStopper stopper) throws Exception {
|
||||
}
|
||||
|
||||
private boolean transportStartedAtLeastOnce() {
|
||||
return socketTransportStarted.getCount() == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getReceiveCounter() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getRemoteAddress() {
|
||||
return "MQTTSocket_" + this.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void oneway(Object command) throws IOException {
|
||||
try {
|
||||
|
@ -128,23 +91,4 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM
|
|||
outbound.sendMessage(bytes.getData(), 0, bytes.getLength());
|
||||
}
|
||||
|
||||
@Override
|
||||
public X509Certificate[] getPeerCertificates() {
|
||||
return new X509Certificate[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQTTInactivityMonitor getInactivityMonitor() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQTTWireFormat getWireFormat() {
|
||||
return wireFormat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBrokerService(BrokerService brokerService) {
|
||||
this.brokerService = brokerService;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,12 +28,16 @@ import org.slf4j.LoggerFactory;
|
|||
/**
|
||||
* Implements web socket and mediates between servlet and the broker
|
||||
*/
|
||||
class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage {
|
||||
public class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class);
|
||||
|
||||
private Connection outbound;
|
||||
|
||||
public StompSocket(String remoteAddress) {
|
||||
super(remoteAddress);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleStopped() throws IOException {
|
||||
if (outbound != null && outbound.isOpen()) {
|
||||
|
|
|
@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse;
|
|||
|
||||
import org.apache.activemq.transport.Transport;
|
||||
import org.apache.activemq.transport.TransportAcceptListener;
|
||||
import org.apache.activemq.transport.util.HttpTransportUtils;
|
||||
import org.eclipse.jetty.websocket.WebSocket;
|
||||
import org.eclipse.jetty.websocket.WebSocketServlet;
|
||||
|
||||
|
@ -54,9 +55,9 @@ public class WSServlet extends WebSocketServlet {
|
|||
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) {
|
||||
WebSocket socket;
|
||||
if (protocol != null && protocol.startsWith("mqtt")) {
|
||||
socket = new MQTTSocket();
|
||||
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(request));
|
||||
} else {
|
||||
socket = new StompSocket();
|
||||
socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(request));
|
||||
}
|
||||
listener.onAccept((Transport) socket);
|
||||
return socket;
|
||||
|
|
|
@ -16,17 +16,13 @@
|
|||
*/
|
||||
package org.apache.activemq.transport.ws.jetty9;
|
||||
|
||||
import org.apache.activemq.broker.BrokerService;
|
||||
import org.apache.activemq.broker.BrokerServiceAware;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import org.apache.activemq.command.Command;
|
||||
import org.apache.activemq.transport.TransportSupport;
|
||||
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
|
||||
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
|
||||
import org.apache.activemq.transport.mqtt.MQTTTransport;
|
||||
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
|
||||
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
|
||||
import org.apache.activemq.util.ByteSequence;
|
||||
import org.apache.activemq.util.IOExceptionSupport;
|
||||
import org.apache.activemq.util.ServiceStopper;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.api.WebSocketListener;
|
||||
import org.fusesource.mqtt.codec.DISCONNECT;
|
||||
|
@ -34,47 +30,13 @@ import org.fusesource.mqtt.codec.MQTTFrame;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
public class MQTTSocket extends TransportSupport implements WebSocketListener, MQTTTransport, BrokerServiceAware {
|
||||
public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
|
||||
Session session;
|
||||
MQTTProtocolConverter protocolConverter = null;
|
||||
MQTTWireFormat wireFormat = new MQTTWireFormat();
|
||||
private final CountDownLatch socketTransportStarted = new CountDownLatch(1);
|
||||
private BrokerService brokerService;
|
||||
|
||||
private MQTTProtocolConverter getProtocolConverter() {
|
||||
if( protocolConverter == null ) {
|
||||
protocolConverter = new MQTTProtocolConverter(this, brokerService);
|
||||
}
|
||||
return protocolConverter;
|
||||
}
|
||||
|
||||
protected void doStart() throws Exception {
|
||||
socketTransportStarted.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doStop(ServiceStopper stopper) throws Exception {
|
||||
}
|
||||
|
||||
private boolean transportStartedAtLeastOnce() {
|
||||
return socketTransportStarted.getCount() == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getReceiveCounter() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getRemoteAddress() {
|
||||
return "MQTTSocket_" + this.hashCode();
|
||||
public MQTTSocket(String remoteAddress) {
|
||||
super(remoteAddress);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -97,26 +59,6 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener,
|
|||
session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public X509Certificate[] getPeerCertificates() {
|
||||
return new X509Certificate[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQTTInactivityMonitor getInactivityMonitor() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQTTWireFormat getWireFormat() {
|
||||
return wireFormat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBrokerService(BrokerService brokerService) {
|
||||
this.brokerService = brokerService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebSocketBinary(byte[] bytes, int offset, int length) {
|
||||
if (!transportStartedAtLeastOnce()) {
|
||||
|
|
|
@ -29,12 +29,16 @@ import org.slf4j.LoggerFactory;
|
|||
/**
|
||||
* Implements web socket and mediates between servlet and the broker
|
||||
*/
|
||||
class StompSocket extends AbstractStompSocket implements WebSocketListener {
|
||||
public class StompSocket extends AbstractStompSocket implements WebSocketListener {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class);
|
||||
|
||||
private Session session;
|
||||
|
||||
public StompSocket(String remoteAddress) {
|
||||
super(remoteAddress);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendToStomp(StompFrame command) throws IOException {
|
||||
session.getRemote().sendString(command.format());
|
||||
|
|
|
@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse;
|
|||
|
||||
import org.apache.activemq.transport.Transport;
|
||||
import org.apache.activemq.transport.TransportAcceptListener;
|
||||
import org.apache.activemq.transport.util.HttpTransportUtils;
|
||||
import org.eclipse.jetty.websocket.api.WebSocketListener;
|
||||
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
|
||||
|
@ -62,10 +63,10 @@ public class WSServlet extends WebSocketServlet {
|
|||
public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
|
||||
WebSocketListener socket;
|
||||
if (req.getSubProtocols().contains("mqtt")) {
|
||||
socket = new MQTTSocket();
|
||||
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
|
||||
resp.setAcceptedSubProtocol("mqtt");
|
||||
} else {
|
||||
socket = new StompSocket();
|
||||
socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
|
||||
resp.setAcceptedSubProtocol("stomp");
|
||||
}
|
||||
listener.onAccept((Transport) socket);
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
package org.apache.activemq.transport.util;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class HttpTransportUtilsTest {
|
||||
|
||||
@Test
|
||||
public void testGenerateWsRemoteAddress() {
|
||||
HttpServletRequest request = mock(HttpServletRequest.class);
|
||||
when(request.getScheme()).thenReturn("http");
|
||||
when(request.getRemoteAddr()).thenReturn("localhost");
|
||||
when(request.getRemotePort()).thenReturn(8080);
|
||||
|
||||
assertEquals("ws://localhost:8080", HttpTransportUtils.generateWsRemoteAddress(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGenerateWssRemoteAddress() {
|
||||
HttpServletRequest request = mock(HttpServletRequest.class);
|
||||
when(request.getScheme()).thenReturn("https");
|
||||
when(request.getRemoteAddr()).thenReturn("localhost");
|
||||
when(request.getRemotePort()).thenReturn(8443);
|
||||
|
||||
assertEquals("wss://localhost:8443", HttpTransportUtils.generateWsRemoteAddress(request));
|
||||
}
|
||||
|
||||
@Test(expected=IllegalArgumentException.class)
|
||||
public void testNullHttpServleRequest() {
|
||||
HttpTransportUtils.generateWsRemoteAddress(null);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package org.apache.activemq.transport.ws;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class SocketTest {
|
||||
|
||||
@Test
|
||||
public void testStompSocketRemoteAddress() {
|
||||
|
||||
org.apache.activemq.transport.ws.jetty8.StompSocket stompSocketJetty8 =
|
||||
new org.apache.activemq.transport.ws.jetty8.StompSocket("ws://localhost:8080");
|
||||
|
||||
assertEquals("ws://localhost:8080", stompSocketJetty8.getRemoteAddress());
|
||||
|
||||
org.apache.activemq.transport.ws.jetty9.StompSocket stompSocketJetty9 =
|
||||
new org.apache.activemq.transport.ws.jetty9.StompSocket("ws://localhost:8080");
|
||||
|
||||
assertEquals("ws://localhost:8080", stompSocketJetty9.getRemoteAddress());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMqttSocketRemoteAddress() {
|
||||
|
||||
org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty8 =
|
||||
new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080");
|
||||
|
||||
assertEquals("ws://localhost:8080", mqttSocketJetty8.getRemoteAddress());
|
||||
|
||||
org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty9 =
|
||||
new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080");
|
||||
|
||||
assertEquals("ws://localhost:8080", mqttSocketJetty9.getRemoteAddress());
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue