diff --git a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java index 280fd81f64..913ebecaec 100644 --- a/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java +++ b/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java @@ -41,8 +41,12 @@ import org.apache.activemq.thread.TaskRunnerFactory; import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.ServiceStopper; import org.apache.activemq.wireformat.WireFormat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -public class NIOSSLTransport extends NIOTransport { +public class NIOSSLTransport extends NIOTransport { + + private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); protected boolean needClientAuth; protected boolean wantClientAuth; @@ -79,15 +83,36 @@ public class NIOSSLTransport extends NIOTransport { sslContext = SSLContext.getDefault(); } + String remoteHost = null; + int remotePort = -1; + + try { + URI remoteAddress = new URI(this.getRemoteAddress()); + remoteHost = remoteAddress.getHost(); + remotePort = remoteAddress.getPort(); + } catch (Exception e) { + } + // initialize engine, the initial sslSession we get will need to be // updated once the ssl handshake process is completed. - sslEngine = sslContext.createSSLEngine(); + if (remoteHost != null && remotePort != -1) { + sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); + } else { + sslEngine = sslContext.createSSLEngine(); + } + sslEngine.setUseClientMode(false); if (enabledCipherSuites != null) { sslEngine.setEnabledCipherSuites(enabledCipherSuites); } - sslEngine.setNeedClientAuth(needClientAuth); - sslEngine.setWantClientAuth(wantClientAuth); + + if (wantClientAuth) { + sslEngine.setWantClientAuth(wantClientAuth); + } + + if (needClientAuth) { + sslEngine.setNeedClientAuth(needClientAuth); + } sslSession = sslEngine.getSession(); @@ -107,31 +132,31 @@ public class NIOSSLTransport extends NIOTransport { } } - protected void finishHandshake() throws Exception { - if (handshakeInProgress) { - handshakeInProgress = false; - nextFrameSize = -1; + protected void finishHandshake() throws Exception { + if (handshakeInProgress) { + handshakeInProgress = false; + nextFrameSize = -1; - // Once handshake completes we need to ask for the now real sslSession - // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the - // cipher suite. - sslSession = sslEngine.getSession(); + // Once handshake completes we need to ask for the now real sslSession + // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the + // cipher suite. + sslSession = sslEngine.getSession(); - // listen for events telling us when the socket is readable. - selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { - public void onSelect(SelectorSelection selection) { - serviceRead(); - } + // listen for events telling us when the socket is readable. + selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { + public void onSelect(SelectorSelection selection) { + serviceRead(); + } - public void onError(SelectorSelection selection, Throwable error) { - if (error instanceof IOException) { - onException((IOException) error); - } else { - onException(IOExceptionSupport.create(error)); - } - } - }); - } + public void onError(SelectorSelection selection, Throwable error) { + if (error instanceof IOException) { + onException((IOException) error); + } else { + onException(IOExceptionSupport.create(error)); + } + } + }); + } } protected void serviceRead() { @@ -143,7 +168,7 @@ public class NIOSSLTransport extends NIOTransport { ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); plain.position(plain.limit()); - while(true) { + while (true) { if (!plain.hasRemaining()) { if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { @@ -153,12 +178,11 @@ public class NIOSSLTransport extends NIOTransport { } int readCount = secureRead(plain); - if (readCount == 0) break; // channel is closed, cleanup - if (readCount== -1) { + if (readCount == -1) { onException(new EOFException()); selection.close(); break; @@ -181,7 +205,8 @@ public class NIOSSLTransport extends NIOTransport { if (wireFormat instanceof OpenWireFormat) { long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); if (nextFrameSize > maxFrameSize) { - throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); + throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); } } currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); @@ -213,8 +238,7 @@ public class NIOSSLTransport extends NIOTransport { if (bytesRead == -1) { sslEngine.closeInbound(); - if (inputBuffer.position() == 0 || - status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { + if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { return -1; } } @@ -226,18 +250,17 @@ public class NIOSSLTransport extends NIOTransport { SSLEngineResult res; do { res = sslEngine.unwrap(inputBuffer, plain); - } while (res.getStatus() == SSLEngineResult.Status.OK && - res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP && - res.bytesProduced() == 0); + } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP + && res.bytesProduced() == 0); if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { - finishHandshake(); + finishHandshake(); } status = res.getStatus(); handshakeStatus = res.getHandshakeStatus(); - //TODO deal with BUFFER_OVERFLOW + // TODO deal with BUFFER_OVERFLOW if (status == SSLEngineResult.Status.CLOSED) { sslEngine.closeInbound(); @@ -254,22 +277,22 @@ public class NIOSSLTransport extends NIOTransport { handshakeInProgress = true; while (true) { switch (sslEngine.getHandshakeStatus()) { - case NEED_UNWRAP: - secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); - break; - case NEED_TASK: - Runnable task; - while ((task = sslEngine.getDelegatedTask()) != null) { - taskRunnerFactory.execute(task); - } - break; - case NEED_WRAP: - ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0)); - break; - case FINISHED: - case NOT_HANDSHAKING: - finishHandshake(); - return; + case NEED_UNWRAP: + secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); + break; + case NEED_TASK: + Runnable task; + while ((task = sslEngine.getDelegatedTask()) != null) { + taskRunnerFactory.execute(task); + } + break; + case NEED_WRAP: + ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); + break; + case FINISHED: + case NOT_HANDSHAKING: + finishHandshake(); + return; } } } @@ -295,14 +318,15 @@ public class NIOSSLTransport extends NIOTransport { } /** - * Overriding in order to add the client's certificates to ConnectionInfo Commmands. + * Overriding in order to add the client's certificates to ConnectionInfo Commands. * - * @param command The Command coming in. + * @param command + * The Command coming in. */ @Override public void doConsume(Object command) { if (command instanceof ConnectionInfo) { - ConnectionInfo connectionInfo = (ConnectionInfo)command; + ConnectionInfo connectionInfo = (ConnectionInfo) command; connectionInfo.setTransportContext(getPeerCertificates()); } super.doConsume(command); @@ -315,10 +339,13 @@ public class NIOSSLTransport extends NIOTransport { X509Certificate[] clientCertChain = null; try { - if (sslSession != null) { - clientCertChain = (X509Certificate[])sslSession.getPeerCertificates(); + if (sslEngine.getSession() != null) { + clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); } } catch (SSLPeerUnverifiedException e) { + if (LOG.isTraceEnabled()) { + LOG.trace("Failed to get peer certificates.", e); + } } return clientCertChain;