diff --git a/src/main/java/org/elasticsearch/shield/ssl/AbstractSSLService.java b/src/main/java/org/elasticsearch/shield/ssl/AbstractSSLService.java index 5de1a53ed0d..f59cb7ebfb2 100644 --- a/src/main/java/org/elasticsearch/shield/ssl/AbstractSSLService.java +++ b/src/main/java/org/elasticsearch/shield/ssl/AbstractSSLService.java @@ -22,6 +22,9 @@ import org.elasticsearch.shield.ShieldSettingsException; import javax.net.ssl.*; import java.io.InputStream; import java.nio.file.Files; +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; import java.security.KeyStore; import java.util.ArrayList; import java.util.Arrays; @@ -56,7 +59,8 @@ public abstract class AbstractSSLService extends AbstractComponent { * @return A SSLSocketFactory (for client-side SSL handshaking) */ public SSLSocketFactory sslSocketFactory() { - return sslContext(Settings.EMPTY).getSocketFactory(); + SSLSocketFactory socketFactory = sslContext().getSocketFactory(); + return new ShieldSSLSocketFactory(socketFactory, supportedProtocols(), supportedCiphers(socketFactory.getSupportedCipherSuites(), ciphers())); } public String[] supportedProtocols() { @@ -317,4 +321,77 @@ public abstract class AbstractSSLService extends AbstractComponent { return result; } } + + /** + * This socket factory set the protocols and ciphers on each SSLSocket after it is created + */ + static class ShieldSSLSocketFactory extends SSLSocketFactory { + + private final SSLSocketFactory delegate; + private final String[] supportedProtocols; + private final String[] ciphers; + + ShieldSSLSocketFactory(SSLSocketFactory delegate, String[] supportedProtocols, String[] ciphers) { + this.delegate = delegate; + this.supportedProtocols = supportedProtocols; + this.ciphers = ciphers; + } + + @Override + public String[] getDefaultCipherSuites() { + return ciphers; + } + + @Override + public String[] getSupportedCipherSuites() { + return delegate.getSupportedCipherSuites(); + } + + @Override + public Socket createSocket() throws IOException { + SSLSocket sslSocket = (SSLSocket) delegate.createSocket(); + configureSSLSocket(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException { + SSLSocket sslSocket = (SSLSocket) delegate.createSocket(socket, host, port, autoClose); + configureSSLSocket(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port); + configureSSLSocket(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException { + SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localHost, localPort); + configureSSLSocket(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(InetAddress host, int port) throws IOException { + SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port); + configureSSLSocket(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException { + SSLSocket sslSocket = (SSLSocket) delegate.createSocket(address, port, localAddress, localPort); + configureSSLSocket(sslSocket); + return sslSocket; + } + + private void configureSSLSocket(SSLSocket socket) { + socket.setEnabledProtocols(supportedProtocols); + socket.setEnabledCipherSuites(ciphers); + } + } } diff --git a/src/test/java/org/elasticsearch/shield/ssl/ClientSSLServiceTests.java b/src/test/java/org/elasticsearch/shield/ssl/ClientSSLServiceTests.java index c6a0bfbe68b..465d0c6b66e 100644 --- a/src/test/java/org/elasticsearch/shield/ssl/ClientSSLServiceTests.java +++ b/src/test/java/org/elasticsearch/shield/ssl/ClientSSLServiceTests.java @@ -17,10 +17,7 @@ import org.elasticsearch.test.junit.annotations.Network; import org.junit.Before; import org.junit.Test; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.*; import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; @@ -233,4 +230,19 @@ public class ClientSSLServiceTests extends ElasticsearchTestCase { .build(), env); sslService.createSSLEngine(); } + + @Test + public void testThatSSLSocketFactoryHasProperCiphersAndProtocols() throws Exception { + ClientSSLService sslService = new ClientSSLService(settingsBuilder() + .put("shield.ssl.keystore.path", testclientStore) + .put("shield.ssl.keystore.password", "testclient") + .build(), env); + SSLSocketFactory factory = sslService.sslSocketFactory(); + assertThat(factory.getDefaultCipherSuites(), is(sslService.ciphers())); + + try (SSLSocket socket = (SSLSocket) factory.createSocket()) { + assertThat(socket.getEnabledCipherSuites(), is(sslService.ciphers())); + assertThat(socket.getEnabledProtocols(), is(sslService.supportedProtocols())); + } + } } diff --git a/src/test/java/org/elasticsearch/shield/ssl/ServerSSLServiceTests.java b/src/test/java/org/elasticsearch/shield/ssl/ServerSSLServiceTests.java index 3bf6aa6a34c..4b995d66506 100644 --- a/src/test/java/org/elasticsearch/shield/ssl/ServerSSLServiceTests.java +++ b/src/test/java/org/elasticsearch/shield/ssl/ServerSSLServiceTests.java @@ -14,9 +14,7 @@ import org.elasticsearch.test.ElasticsearchTestCase; import org.junit.Before; import org.junit.Test; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.*; import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; @@ -197,4 +195,19 @@ public class ServerSSLServiceTests extends ElasticsearchTestCase { .build(), settingsFilter, env); sslService.createSSLEngine(); } + + @Test + public void testThatSSLSocketFactoryHasProperCiphersAndProtocols() throws Exception { + ServerSSLService sslService = new ServerSSLService(settingsBuilder() + .put("shield.ssl.keystore.path", testnodeStore) + .put("shield.ssl.keystore.password", "testnode") + .build(), settingsFilter, env); + SSLSocketFactory factory = sslService.sslSocketFactory(); + assertThat(factory.getDefaultCipherSuites(), is(sslService.ciphers())); + + try (SSLSocket socket = (SSLSocket) factory.createSocket()) { + assertThat(socket.getEnabledCipherSuites(), is(sslService.ciphers())); + assertThat(socket.getEnabledProtocols(), is(sslService.supportedProtocols())); + } + } }