diff --git a/activemq-http/pom.xml b/activemq-http/pom.xml index 7370997550..59de71acca 100755 --- a/activemq-http/pom.xml +++ b/activemq-http/pom.xml @@ -115,6 +115,11 @@ provided true + + org.mockito + mockito-core + test + diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/util/HttpTransportUtils.java b/activemq-http/src/main/java/org/apache/activemq/transport/util/HttpTransportUtils.java new file mode 100644 index 0000000000..55340e5346 --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/util/HttpTransportUtils.java @@ -0,0 +1,20 @@ +package org.apache.activemq.transport.util; + +import javax.servlet.http.HttpServletRequest; + +public class HttpTransportUtils { + + public static String generateWsRemoteAddress(HttpServletRequest request) { + if (request == null) { + throw new IllegalArgumentException("HttpServletRequest must not be null."); + } + + StringBuilder remoteAddress = new StringBuilder(); + String scheme = request.getScheme(); + remoteAddress.append(scheme != null && scheme.toLowerCase().equals("https") ? "wss://" : "ws://"); + remoteAddress.append(request.getRemoteAddr()); + remoteAddress.append(":"); + remoteAddress.append(request.getRemotePort()); + return remoteAddress.toString(); + } +} diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java new file mode 100644 index 0000000000..406741c636 --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java @@ -0,0 +1,76 @@ +package org.apache.activemq.transport.ws; + +import java.security.cert.X509Certificate; +import java.util.concurrent.CountDownLatch; + +import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.broker.BrokerServiceAware; +import org.apache.activemq.transport.TransportSupport; +import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor; +import org.apache.activemq.transport.mqtt.MQTTProtocolConverter; +import org.apache.activemq.transport.mqtt.MQTTTransport; +import org.apache.activemq.transport.mqtt.MQTTWireFormat; +import org.apache.activemq.util.ServiceStopper; + +public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware { + + protected MQTTWireFormat wireFormat = new MQTTWireFormat(); + protected final CountDownLatch socketTransportStarted = new CountDownLatch(1); + protected MQTTProtocolConverter protocolConverter = null; + private BrokerService brokerService; + protected final String remoteAddress; + + public AbstractMQTTSocket(String remoteAddress) { + super(); + this.remoteAddress = remoteAddress; + } + + protected boolean transportStartedAtLeastOnce() { + return socketTransportStarted.getCount() == 0; + } + + protected void doStart() throws Exception { + socketTransportStarted.countDown(); + } + + @Override + protected void doStop(ServiceStopper stopper) throws Exception { + } + + protected MQTTProtocolConverter getProtocolConverter() { + if( protocolConverter == null ) { + protocolConverter = new MQTTProtocolConverter(this, brokerService); + } + return protocolConverter; + } + + @Override + public int getReceiveCounter() { + return 0; + } + + @Override + public X509Certificate[] getPeerCertificates() { + return new X509Certificate[0]; + } + + @Override + public MQTTInactivityMonitor getInactivityMonitor() { + return null; + } + + @Override + public MQTTWireFormat getWireFormat() { + return wireFormat; + } + + @Override + public String getRemoteAddress() { + return remoteAddress; + } + + @Override + public void setBrokerService(BrokerService brokerService) { + this.brokerService = brokerService; + } +} diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java index 739e2fcf0a..4ffa6c9ad3 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java @@ -45,6 +45,13 @@ public abstract class AbstractStompSocket extends TransportSupport implements St protected final CountDownLatch socketTransportStarted = new CountDownLatch(1); protected final StompInactivityMonitor stompInactivityMonitor = new StompInactivityMonitor(this, wireFormat); protected volatile int receiveCounter; + protected final String remoteAddress; + + + public AbstractStompSocket(String remoteAddress) { + super(); + this.remoteAddress = remoteAddress; + } @Override public void oneway(Object command) throws IOException { @@ -100,7 +107,7 @@ public abstract class AbstractStompSocket extends TransportSupport implements St @Override public String getRemoteAddress() { - return "StompSocket_" + this.hashCode(); + return remoteAddress; } @Override diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java index 58e9134f42..43f08e41da 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java @@ -16,35 +16,26 @@ */ package org.apache.activemq.transport.ws.jetty8; -import org.apache.activemq.broker.BrokerService; -import org.apache.activemq.broker.BrokerServiceAware; +import java.io.IOException; + import org.apache.activemq.command.Command; -import org.apache.activemq.transport.TransportSupport; -import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor; -import org.apache.activemq.transport.mqtt.MQTTProtocolConverter; -import org.apache.activemq.transport.mqtt.MQTTTransport; -import org.apache.activemq.transport.mqtt.MQTTWireFormat; +import org.apache.activemq.transport.ws.AbstractMQTTSocket; import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.IOExceptionSupport; -import org.apache.activemq.util.ServiceStopper; import org.eclipse.jetty.websocket.WebSocket; import org.fusesource.mqtt.codec.DISCONNECT; import org.fusesource.mqtt.codec.MQTTFrame; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.security.cert.X509Certificate; -import java.util.concurrent.CountDownLatch; - -public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryMessage, MQTTTransport, BrokerServiceAware { +public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinaryMessage { private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); Connection outbound; - MQTTProtocolConverter protocolConverter = null; - MQTTWireFormat wireFormat = new MQTTWireFormat(); - private final CountDownLatch socketTransportStarted = new CountDownLatch(1); - private BrokerService brokerService; + + public MQTTSocket(String remoteAddress) { + super(remoteAddress); + } @Override public void onMessage(byte[] bytes, int offset, int length) { @@ -65,12 +56,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM } } - private MQTTProtocolConverter getProtocolConverter() { - if( protocolConverter == null ) { - protocolConverter = new MQTTProtocolConverter(this, brokerService); - } - return protocolConverter; - } @Override public void onOpen(Connection connection) { @@ -86,28 +71,6 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM } } - protected void doStart() throws Exception { - socketTransportStarted.countDown(); - } - - @Override - protected void doStop(ServiceStopper stopper) throws Exception { - } - - private boolean transportStartedAtLeastOnce() { - return socketTransportStarted.getCount() == 0; - } - - @Override - public int getReceiveCounter() { - return 0; - } - - @Override - public String getRemoteAddress() { - return "MQTTSocket_" + this.hashCode(); - } - @Override public void oneway(Object command) throws IOException { try { @@ -128,23 +91,4 @@ public class MQTTSocket extends TransportSupport implements WebSocket.OnBinaryM outbound.sendMessage(bytes.getData(), 0, bytes.getLength()); } - @Override - public X509Certificate[] getPeerCertificates() { - return new X509Certificate[0]; - } - - @Override - public MQTTInactivityMonitor getInactivityMonitor() { - return null; - } - - @Override - public MQTTWireFormat getWireFormat() { - return wireFormat; - } - - @Override - public void setBrokerService(BrokerService brokerService) { - this.brokerService = brokerService; - } } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/StompSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/StompSocket.java index 23357bd6eb..a2d07b99c4 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/StompSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/StompSocket.java @@ -28,12 +28,16 @@ import org.slf4j.LoggerFactory; /** * Implements web socket and mediates between servlet and the broker */ -class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage { +public class StompSocket extends AbstractStompSocket implements WebSocket.OnTextMessage { private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class); private Connection outbound; + public StompSocket(String remoteAddress) { + super(remoteAddress); + } + @Override public void handleStopped() throws IOException { if (outbound != null && outbound.isOpen()) { diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java index ac589b7df5..c5cb706434 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java @@ -24,6 +24,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.TransportAcceptListener; +import org.apache.activemq.transport.util.HttpTransportUtils; import org.eclipse.jetty.websocket.WebSocket; import org.eclipse.jetty.websocket.WebSocketServlet; @@ -54,9 +55,9 @@ public class WSServlet extends WebSocketServlet { public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) { WebSocket socket; if (protocol != null && protocol.startsWith("mqtt")) { - socket = new MQTTSocket(); + socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(request)); } else { - socket = new StompSocket(); + socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(request)); } listener.onAccept((Transport) socket); return socket; diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java index 4d7dac3a2e..ef7631a057 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java @@ -16,17 +16,13 @@ */ package org.apache.activemq.transport.ws.jetty9; -import org.apache.activemq.broker.BrokerService; -import org.apache.activemq.broker.BrokerServiceAware; +import java.io.IOException; +import java.nio.ByteBuffer; + import org.apache.activemq.command.Command; -import org.apache.activemq.transport.TransportSupport; -import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor; -import org.apache.activemq.transport.mqtt.MQTTProtocolConverter; -import org.apache.activemq.transport.mqtt.MQTTTransport; -import org.apache.activemq.transport.mqtt.MQTTWireFormat; +import org.apache.activemq.transport.ws.AbstractMQTTSocket; import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.IOExceptionSupport; -import org.apache.activemq.util.ServiceStopper; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.WebSocketListener; import org.fusesource.mqtt.codec.DISCONNECT; @@ -34,47 +30,13 @@ import org.fusesource.mqtt.codec.MQTTFrame; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.security.cert.X509Certificate; -import java.util.concurrent.CountDownLatch; - -public class MQTTSocket extends TransportSupport implements WebSocketListener, MQTTTransport, BrokerServiceAware { +public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener { private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); Session session; - MQTTProtocolConverter protocolConverter = null; - MQTTWireFormat wireFormat = new MQTTWireFormat(); - private final CountDownLatch socketTransportStarted = new CountDownLatch(1); - private BrokerService brokerService; - private MQTTProtocolConverter getProtocolConverter() { - if( protocolConverter == null ) { - protocolConverter = new MQTTProtocolConverter(this, brokerService); - } - return protocolConverter; - } - - protected void doStart() throws Exception { - socketTransportStarted.countDown(); - } - - @Override - protected void doStop(ServiceStopper stopper) throws Exception { - } - - private boolean transportStartedAtLeastOnce() { - return socketTransportStarted.getCount() == 0; - } - - @Override - public int getReceiveCounter() { - return 0; - } - - @Override - public String getRemoteAddress() { - return "MQTTSocket_" + this.hashCode(); + public MQTTSocket(String remoteAddress) { + super(remoteAddress); } @Override @@ -97,26 +59,6 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener, session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength())); } - @Override - public X509Certificate[] getPeerCertificates() { - return new X509Certificate[0]; - } - - @Override - public MQTTInactivityMonitor getInactivityMonitor() { - return null; - } - - @Override - public MQTTWireFormat getWireFormat() { - return wireFormat; - } - - @Override - public void setBrokerService(BrokerService brokerService) { - this.brokerService = brokerService; - } - @Override public void onWebSocketBinary(byte[] bytes, int offset, int length) { if (!transportStartedAtLeastOnce()) { @@ -142,7 +84,7 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener, getProtocolConverter().onMQTTCommand(new DISCONNECT().encode()); } catch (Exception e) { LOG.warn("Failed to close WebSocket", e); - } + } } @Override @@ -152,10 +94,10 @@ public class MQTTSocket extends TransportSupport implements WebSocketListener, @Override public void onWebSocketError(Throwable arg0) { - + } @Override - public void onWebSocketText(String arg0) { + public void onWebSocketText(String arg0) { } } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/StompSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/StompSocket.java index be7dc30c77..b7edcbe59f 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/StompSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/StompSocket.java @@ -29,12 +29,16 @@ import org.slf4j.LoggerFactory; /** * Implements web socket and mediates between servlet and the broker */ -class StompSocket extends AbstractStompSocket implements WebSocketListener { +public class StompSocket extends AbstractStompSocket implements WebSocketListener { private static final Logger LOG = LoggerFactory.getLogger(StompSocket.class); private Session session; + public StompSocket(String remoteAddress) { + super(remoteAddress); + } + @Override public void sendToStomp(StompFrame command) throws IOException { session.getRemote().sendString(command.format()); diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java index 1bc744b2f8..7684318815 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java @@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.TransportAcceptListener; +import org.apache.activemq.transport.util.HttpTransportUtils; import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; @@ -62,10 +63,10 @@ public class WSServlet extends WebSocketServlet { public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) { WebSocketListener socket; if (req.getSubProtocols().contains("mqtt")) { - socket = new MQTTSocket(); + socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); resp.setAcceptedSubProtocol("mqtt"); } else { - socket = new StompSocket(); + socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); resp.setAcceptedSubProtocol("stomp"); } listener.onAccept((Transport) socket); diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java new file mode 100644 index 0000000000..51fcd048b4 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java @@ -0,0 +1,37 @@ +package org.apache.activemq.transport.util; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import javax.servlet.http.HttpServletRequest; + +import org.junit.Test; + +public class HttpTransportUtilsTest { + + @Test + public void testGenerateWsRemoteAddress() { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getScheme()).thenReturn("http"); + when(request.getRemoteAddr()).thenReturn("localhost"); + when(request.getRemotePort()).thenReturn(8080); + + assertEquals("ws://localhost:8080", HttpTransportUtils.generateWsRemoteAddress(request)); + } + + @Test + public void testGenerateWssRemoteAddress() { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getScheme()).thenReturn("https"); + when(request.getRemoteAddr()).thenReturn("localhost"); + when(request.getRemotePort()).thenReturn(8443); + + assertEquals("wss://localhost:8443", HttpTransportUtils.generateWsRemoteAddress(request)); + } + + @Test(expected=IllegalArgumentException.class) + public void testNullHttpServleRequest() { + HttpTransportUtils.generateWsRemoteAddress(null); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java new file mode 100644 index 0000000000..871029a7af --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java @@ -0,0 +1,37 @@ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +public class SocketTest { + + @Test + public void testStompSocketRemoteAddress() { + + org.apache.activemq.transport.ws.jetty8.StompSocket stompSocketJetty8 = + new org.apache.activemq.transport.ws.jetty8.StompSocket("ws://localhost:8080"); + + assertEquals("ws://localhost:8080", stompSocketJetty8.getRemoteAddress()); + + org.apache.activemq.transport.ws.jetty9.StompSocket stompSocketJetty9 = + new org.apache.activemq.transport.ws.jetty9.StompSocket("ws://localhost:8080"); + + assertEquals("ws://localhost:8080", stompSocketJetty9.getRemoteAddress()); + } + + @Test + public void testMqttSocketRemoteAddress() { + + org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty8 = + new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080"); + + assertEquals("ws://localhost:8080", mqttSocketJetty8.getRemoteAddress()); + + org.apache.activemq.transport.ws.jetty8.MQTTSocket mqttSocketJetty9 = + new org.apache.activemq.transport.ws.jetty8.MQTTSocket("ws://localhost:8080"); + + assertEquals("ws://localhost:8080", mqttSocketJetty9.getRemoteAddress()); + } + +} diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/network/NetworkRouteTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/network/NetworkRouteTest.java index 4cd97b0cbc..0f50afefb9 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/network/NetworkRouteTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/network/NetworkRouteTest.java @@ -232,6 +232,7 @@ public class NetworkRouteTest { msg.setDestination(new ActiveMQTopic("test")); msgDispatch = new MessageDispatch(); msgDispatch.setMessage(msg); + msgDispatch.setDestination(msg.getDestination()); ConsumerInfo path1 = new ConsumerInfo(); path1.setDestination(msg.getDestination());