diff --git a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOOutputStream.java b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOOutputStream.java index 6e183ff5c2..3bc02c8139 100644 --- a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOOutputStream.java +++ b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOOutputStream.java @@ -25,6 +25,9 @@ import java.nio.channels.WritableByteChannel; import org.apache.activemq.transport.tcp.TimeStampStream; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; + /** * An optimized buffered outputstream for Tcp * @@ -43,6 +46,8 @@ public class NIOOutputStream extends OutputStream implements TimeStampStream { private boolean closed; private volatile long writeTimestamp = -1;//concurrent reads of this value + private SSLEngine engine; + /** * Constructor * @@ -149,7 +154,16 @@ public class NIOOutputStream extends OutputStream implements TimeStampStream { } protected void write(ByteBuffer data) throws IOException { - int remaining = data.remaining(); + ByteBuffer plain; + if (engine != null) { + plain = ByteBuffer.allocate(engine.getSession().getPacketBufferSize()); + plain.clear(); + engine.wrap(data, plain); + } else { + plain = data; + } + plain.flip(); + int remaining = plain.remaining(); int lastRemaining = remaining - 1; long delay = 1; try { @@ -176,7 +190,7 @@ public class NIOOutputStream extends OutputStream implements TimeStampStream { // Since the write is non-blocking, all the data may not have been // written. - out.write(data); + out.write(plain); remaining = data.remaining(); } } finally { @@ -199,4 +213,7 @@ public class NIOOutputStream extends OutputStream implements TimeStampStream { return writeTimestamp; } + public void setEngine(SSLEngine engine) { + this.engine = engine; + } } diff --git a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java new file mode 100644 index 0000000000..330587add9 --- /dev/null +++ b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java @@ -0,0 +1,250 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.nio; + +import org.apache.activemq.command.Command; +import org.apache.activemq.openwire.OpenWireFormat; +import org.apache.activemq.util.IOExceptionSupport; +import org.apache.activemq.util.ServiceStopper; +import org.apache.activemq.wireformat.WireFormat; + +import javax.net.SocketFactory; +import javax.net.ssl.*; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.net.Socket; +import java.net.URI; +import java.net.UnknownHostException; +import java.nio.ByteBuffer; + +public class NIOSSLTransport extends NIOTransport { + + protected SSLContext sslContext; + protected SSLEngine sslEngine; + protected SSLSession sslSession; + + + boolean handshakeInProgress = false; + SSLEngineResult.Status status = null; + SSLEngineResult.HandshakeStatus handshakeStatus = null; + + public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { + super(wireFormat, socketFactory, remoteLocation, localLocation); + } + + public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { + super(wireFormat, socket); + } + + public void setSslContext(SSLContext sslContext) { + this.sslContext = sslContext; + } + + @Override + protected void initializeStreams() throws IOException { + + try { + channel = socket.getChannel(); + channel.configureBlocking(false); + + if (sslContext == null) { + sslContext = SSLContext.getDefault(); + } + + sslEngine = sslContext.createSSLEngine(); + sslEngine.setUseClientMode(false); + sslSession = sslEngine.getSession(); + + inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); + inputBuffer.clear(); + currentBuffer = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); + + NIOOutputStream outputStream = new NIOOutputStream(channel); + outputStream.setEngine(sslEngine); + this.dataOut = new DataOutputStream(outputStream); + this.buffOut = outputStream; + + sslEngine.beginHandshake(); + handshakeStatus = sslEngine.getHandshakeStatus(); + + + doHandshake(); + + } catch (Exception e) { + throw new IOException(e); + } + + } + + protected void finishHandshake() throws Exception { + if (handshakeInProgress) { + handshakeInProgress = false; + nextFrameSize = -1; + + // listen for events telling us when the socket is readable. + selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { + public void onSelect(SelectorSelection selection) { + serviceRead(); + } + + public void onError(SelectorSelection selection, Throwable error) { + if (error instanceof IOException) { + onException((IOException) error); + } else { + onException(IOExceptionSupport.create(error)); + } + } + }); + } + } + + + + protected void serviceRead() { + try { + if (handshakeInProgress) { + doHandshake(); + } + + ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); + plain.position(plain.limit()); + + while (true) { + if (nextFrameSize == -1) { + if (!plain.hasRemaining()) { + plain.clear(); + int readCount = secureRead(plain); + if (readCount == 0) + break; + } + nextFrameSize = plain.getInt(); + if (wireFormat instanceof OpenWireFormat) { + long maxFrameSize = ((OpenWireFormat)wireFormat).getMaxFrameSize(); + if (nextFrameSize > maxFrameSize) { + throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); + } + } + currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); + currentBuffer.putInt(nextFrameSize); + if (currentBuffer.hasRemaining()) { + if (currentBuffer.remaining() >= plain.remaining()) { + currentBuffer.put(plain); + } else { + byte[] fill = new byte[currentBuffer.remaining()]; + plain.get(fill); + currentBuffer.put(fill); + } + } + + if (currentBuffer.hasRemaining()) { + continue; + } else { + currentBuffer.flip(); + Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); + doConsume((Command) command); + + nextFrameSize = -1; + } + } + } + + } catch (IOException e) { + onException(e); + } catch (Throwable e) { + onException(IOExceptionSupport.create(e)); + } + + } + + + + private int secureRead(ByteBuffer plain) throws Exception { + int bytesRead = channel.read(inputBuffer); + if (bytesRead == -1) { + sslEngine.closeInbound(); + if (inputBuffer.position() == 0 || + status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { + return -1; + } + } + + plain.clear(); + + inputBuffer.flip(); + SSLEngineResult res; + do { + res = sslEngine.unwrap(inputBuffer, plain); + } while (res.getStatus() == SSLEngineResult.Status.OK && + res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP && + res.bytesProduced() == 0); + + if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { + finishHandshake(); + } + + status = res.getStatus(); + handshakeStatus = res.getHandshakeStatus(); + + //TODO deal with BUFFER_OVERFLOW + + if (status == SSLEngineResult.Status.CLOSED) { + //TODO do shutdown + sslEngine.closeInbound(); + return -1; + } + + inputBuffer.compact(); + plain.flip(); + + return plain.remaining(); + } + + protected void doHandshake() throws Exception { + handshakeInProgress = true; + while (true) { + switch (sslEngine.getHandshakeStatus()) { + case NEED_UNWRAP: + secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); + break; + case NEED_TASK: + //TODO use the pool + Runnable task; + while ((task = sslEngine.getDelegatedTask()) != null) { + task.run(); + } + break; + case NEED_WRAP: + ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0)); + break; + case FINISHED: + case NOT_HANDSHAKING: + finishHandshake(); + return; + } + } + } + + @Override + protected void doStop(ServiceStopper stopper) throws Exception { + if (channel != null) { + channel.close(); + } + super.doStop(stopper); + } +} diff --git a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransportFactory.java b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransportFactory.java new file mode 100644 index 0000000000..71cc5db259 --- /dev/null +++ b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransportFactory.java @@ -0,0 +1,132 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.activemq.transport.nio; + +import org.apache.activemq.broker.SslContext; +import org.apache.activemq.transport.Transport; +import org.apache.activemq.transport.TransportServer; +import org.apache.activemq.transport.tcp.SslTransport; +import org.apache.activemq.transport.tcp.SslTransportFactory; +import org.apache.activemq.transport.tcp.TcpTransport; +import org.apache.activemq.transport.tcp.TcpTransportServer; +import org.apache.activemq.util.IOExceptionSupport; +import org.apache.activemq.util.IntrospectionSupport; +import org.apache.activemq.wireformat.WireFormat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ServerSocketFactory; +import javax.net.SocketFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import java.io.IOException; +import java.net.Socket; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.Map; + +public class NIOSSLTransportFactory extends NIOTransportFactory { + private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransportFactory.class); + SSLContext context; + + protected TcpTransportServer createTcpTransportServer(URI location, ServerSocketFactory serverSocketFactory) throws IOException, URISyntaxException { + return new TcpTransportServer(this, location, serverSocketFactory) { + protected Transport createTransport(Socket socket, WireFormat format) throws IOException { + NIOSSLTransport transport = new NIOSSLTransport(format, socket); + if (context != null) { + transport.setSslContext(context); + } + return transport; + } + }; + } + + @Override + public TransportServer doBind(URI location) throws IOException { + if (SslContext.getCurrentSslContext() != null) { + try { + context = SslContext.getCurrentSslContext().getSSLContext(); + } catch (Exception e) { + throw new IOException(e); + } + } + return super.doBind(location); + } + + + /** + * Overriding to allow for proper configuration through reflection but delegate to get common + * configuration + */ + public Transport compositeConfigure(Transport transport, WireFormat format, Map options) { + if (transport instanceof SslTransport) { + SslTransport sslTransport = (SslTransport)transport.narrow(SslTransport.class); + IntrospectionSupport.setProperties(sslTransport, options); + } else if (transport instanceof NIOSSLTransport) { + NIOSSLTransport sslTransport = (NIOSSLTransport)transport.narrow(NIOSSLTransport.class); + IntrospectionSupport.setProperties(sslTransport, options); + } + + return super.compositeConfigure(transport, format, options); + } + + /** + * Overriding to use SslTransports. + */ + protected Transport createTransport(URI location, WireFormat wf) throws UnknownHostException, IOException { + + URI localLocation = null; + String path = location.getPath(); + // see if the path is a local URI location + if (path != null && path.length() > 0) { + int localPortIndex = path.indexOf(':'); + try { + Integer.parseInt(path.substring(localPortIndex + 1, path.length())); + String localString = location.getScheme() + ":/" + path; + localLocation = new URI(localString); + } catch (Exception e) { + LOG.warn("path isn't a valid local location for SslTransport to use", e); + } + } + SocketFactory socketFactory = createSocketFactory(); + return new SslTransport(wf, (SSLSocketFactory)socketFactory, location, localLocation, false); + } + + /** + * Creates a new SSL SocketFactory. The given factory will use user-provided + * key and trust managers (if the user provided them). + * + * @return Newly created (Ssl)SocketFactory. + * @throws IOException + */ + protected SocketFactory createSocketFactory() throws IOException { + if( SslContext.getCurrentSslContext()!=null ) { + SslContext ctx = SslContext.getCurrentSslContext(); + try { + return ctx.getSSLContext().getSocketFactory(); + } catch (Exception e) { + throw IOExceptionSupport.create(e); + } + } else { + return SSLSocketFactory.getDefault(); + } + + } + +} diff --git a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransport.java b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransport.java index 083b1ccfbf..44af4be003 100644 --- a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransport.java +++ b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransport.java @@ -45,11 +45,11 @@ import org.apache.activemq.wireformat.WireFormat; public class NIOTransport extends TcpTransport { // private static final Logger log = LoggerFactory.getLogger(NIOTransport.class); - private SocketChannel channel; - private SelectorSelection selection; - private ByteBuffer inputBuffer; - private ByteBuffer currentBuffer; - private int nextFrameSize; + protected SocketChannel channel; + protected SelectorSelection selection; + protected ByteBuffer inputBuffer; + protected ByteBuffer currentBuffer; + protected int nextFrameSize; public NIOTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { super(wireFormat, socketFactory, remoteLocation, localLocation); @@ -89,7 +89,7 @@ public class NIOTransport extends TcpTransport { this.buffOut = outPutStream; } - private void serviceRead() { + protected void serviceRead() { try { while (true) { diff --git a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransportFactory.java b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransportFactory.java index dd068b7376..249e48ba49 100644 --- a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransportFactory.java +++ b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOTransportFactory.java @@ -72,7 +72,7 @@ public class NIOTransportFactory extends TcpTransportFactory { }; } - protected SocketFactory createSocketFactory() { + protected SocketFactory createSocketFactory() throws IOException { return new SocketFactory() { public Socket createSocket() throws IOException { diff --git a/activemq-core/src/main/resources/META-INF/services/org/apache/activemq/transport/nio+ssl b/activemq-core/src/main/resources/META-INF/services/org/apache/activemq/transport/nio+ssl new file mode 100644 index 0000000000..5ad7411a43 --- /dev/null +++ b/activemq-core/src/main/resources/META-INF/services/org/apache/activemq/transport/nio+ssl @@ -0,0 +1,17 @@ +## --------------------------------------------------------------------------- +## Licensed to the Apache Software Foundation (ASF) under one or more +## contributor license agreements. See the NOTICE file distributed with +## this work for additional information regarding copyright ownership. +## The ASF licenses this file to You under the Apache License, Version 2.0 +## (the "License"); you may not use this file except in compliance with +## the License. You may obtain a copy of the License at +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +## See the License for the specific language governing permissions and +## limitations under the License. +## --------------------------------------------------------------------------- +class=org.apache.activemq.transport.nio.NIOSSLTransportFactory \ No newline at end of file diff --git a/activemq-core/src/test/java/org/apache/activemq/transport/nio/NIOSSLTransportBrokerTest.java b/activemq-core/src/test/java/org/apache/activemq/transport/nio/NIOSSLTransportBrokerTest.java new file mode 100644 index 0000000000..4860ec067d --- /dev/null +++ b/activemq-core/src/test/java/org/apache/activemq/transport/nio/NIOSSLTransportBrokerTest.java @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.nio; + +import java.net.URI; +import java.net.URISyntaxException; +import junit.framework.Test; +import junit.textui.TestRunner; +import org.apache.activemq.transport.TransportBrokerTestSupport; + +public class NIOSSLTransportBrokerTest extends TransportBrokerTestSupport { + + public static final String KEYSTORE_TYPE = "jks"; + public static final String PASSWORD = "password"; + public static final String SERVER_KEYSTORE = "src/test/resources/server.keystore"; + public static final String TRUST_KEYSTORE = "src/test/resources/client.keystore"; + + protected String getBindLocation() { + return "nio+ssl://localhost:0"; + } + + @Override + protected URI getBindURI() throws URISyntaxException { + return new URI("nio+ssl://localhost:0"); + } + + protected void setUp() throws Exception { + System.setProperty("javax.net.ssl.trustStore", TRUST_KEYSTORE); + System.setProperty("javax.net.ssl.trustStorePassword", PASSWORD); + System.setProperty("javax.net.ssl.trustStoreType", KEYSTORE_TYPE); + System.setProperty("javax.net.ssl.keyStore", SERVER_KEYSTORE); + System.setProperty("javax.net.ssl.keyStorePassword", PASSWORD); + System.setProperty("javax.net.ssl.keyStoreType", KEYSTORE_TYPE); + //System.setProperty("javax.net.debug", "ssl,handshake,data,trustmanager"); + + maxWait = 10000; + super.setUp(); + } + + @Override + protected void tearDown() throws Exception { + super.tearDown(); + } + + public static Test suite() { + return suite(NIOSSLTransportBrokerTest.class); + } + + public static void main(String[] args) { + TestRunner.run(suite()); + } + +}