From 180d90c84731dcfe3a3c22f2c0b9ab6525046f8d Mon Sep 17 00:00:00 2001 From: Carter Kozak Date: Thu, 2 Nov 2023 10:19:52 -0400 Subject: [PATCH] HTTPCLIENT-2305: SSLConnectionSocketFactory allows socket.connect to be decorated (#499) --- .../testing/sync/TestSSLSocketFactory.java | 40 +++++++++++++ .../http/ssl/SSLConnectionSocketFactory.java | 56 +++++++++++++------ 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/httpclient5-testing/src/test/java/org/apache/hc/client5/testing/sync/TestSSLSocketFactory.java b/httpclient5-testing/src/test/java/org/apache/hc/client5/testing/sync/TestSSLSocketFactory.java index d0c0a3f13..cdcf7bf39 100644 --- a/httpclient5-testing/src/test/java/org/apache/hc/client5/testing/sync/TestSSLSocketFactory.java +++ b/httpclient5-testing/src/test/java/org/apache/hc/client5/testing/sync/TestSSLSocketFactory.java @@ -35,6 +35,7 @@ import java.net.Socket; import java.security.KeyManagementException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; +import java.util.concurrent.atomic.AtomicBoolean; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; @@ -57,6 +58,7 @@ import org.apache.hc.core5.io.CloseMode; import org.apache.hc.core5.ssl.SSLContexts; import org.apache.hc.core5.ssl.TrustStrategy; import org.apache.hc.core5.util.TimeValue; +import org.apache.hc.core5.util.Timeout; import org.hamcrest.CoreMatchers; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; @@ -123,6 +125,44 @@ public class TestSSLSocketFactory { } } + @Test + public void testBasicSslConnectOverride() throws Exception { + this.server = ServerBootstrap.bootstrap() + .setSslContext(SSLTestContexts.createServerSSLContext()) + .create(); + this.server.start(); + + final HttpContext context = new BasicHttpContext(); + final AtomicBoolean connectCalled = new AtomicBoolean(); + final SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( + SSLTestContexts.createClientSSLContext()) { + @Override + protected void connectSocket( + final Socket sock, + final InetSocketAddress remoteAddress, + final Timeout connectTimeout, + final HttpContext context) throws IOException { + connectCalled.set(true); + super.connectSocket(sock, remoteAddress, connectTimeout, context); + } + }; + try (final Socket socket = socketFactory.createSocket(context)) { + final InetSocketAddress remoteAddress = new InetSocketAddress("localhost", this.server.getLocalPort()); + final HttpHost target = new HttpHost("https", "localhost", this.server.getLocalPort()); + try (final SSLSocket sslSocket = (SSLSocket) socketFactory.connectSocket( + TimeValue.ZERO_MILLISECONDS, + socket, + target, + remoteAddress, + null, + context)) { + final SSLSession sslsession = sslSocket.getSession(); + Assertions.assertNotNull(sslsession); + Assertions.assertTrue(connectCalled.get()); + } + } + } + @Test public void testBasicDefaultHostnameVerifier() throws Exception { // @formatter:off diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/ssl/SSLConnectionSocketFactory.java b/httpclient5/src/main/java/org/apache/hc/client5/http/ssl/SSLConnectionSocketFactory.java index 353f4fae3..4e80b4fa7 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/ssl/SSLConnectionSocketFactory.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/ssl/SSLConnectionSocketFactory.java @@ -32,6 +32,7 @@ import java.io.InputStream; import java.net.InetSocketAddress; import java.net.Proxy; import java.net.Socket; +import java.net.SocketAddress; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; @@ -236,22 +237,7 @@ public class SSLConnectionSocketFactory implements LayeredConnectionSocketFactor sock.bind(localAddress); } try { - if (LOG.isDebugEnabled()) { - LOG.debug("Connecting socket to {} with timeout {}", remoteAddress, connectTimeout); - } - // Run this under a doPrivileged to support lib users that run under a SecurityManager this allows granting connect permissions - // only to this library - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - sock.connect(remoteAddress, Timeout.defaultsToDisabled(connectTimeout).toMillisecondsIntBound()); - return null; - }); - } catch (final PrivilegedActionException e) { - Asserts.check(e.getCause() instanceof IOException, - "method contract violation only checked exceptions are wrapped: " + e.getCause()); - // only checked exceptions are wrapped - error and RTExceptions are rethrown by doPrivileged - throw (IOException) e.getCause(); - } + connectSocket(sock, remoteAddress, connectTimeout, context); } catch (final IOException ex) { Closer.closeQuietly(sock); throw ex; @@ -265,6 +251,44 @@ public class SSLConnectionSocketFactory implements LayeredConnectionSocketFactor return createLayeredSocket(sock, host.getHostName(), remoteAddress.getPort(), attachment, context); } + /** + * Connects the socket to the target host with the given resolved remote address using + * {@link Socket#connect(SocketAddress, int)}. This method may be overridden to customize + * how precisely {@link Socket#connect(SocketAddress, int)} is handled without impacting + * other connection establishment code within {@link #executeHandshake(SSLSocket, String, Object, HttpContext)}, + * for example. + * + * @param sock the socket to connect. + * @param remoteAddress the resolved remote address to connect to. + * @param connectTimeout connect timeout. + * @param context the actual HTTP context. + * @throws IOException if an I/O error occurs + */ + protected void connectSocket( + final Socket sock, + final InetSocketAddress remoteAddress, + final Timeout connectTimeout, + final HttpContext context) throws IOException { + Args.notNull(sock, "Socket"); + Args.notNull(remoteAddress, "Remote address"); + if (LOG.isDebugEnabled()) { + LOG.debug("Connecting socket to {} with timeout {}", remoteAddress, connectTimeout); + } + // Run this under a doPrivileged to support lib users that run under a SecurityManager this allows granting connect permissions + // only to this library + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + sock.connect(remoteAddress, Timeout.defaultsToDisabled(connectTimeout).toMillisecondsIntBound()); + return null; + }); + } catch (final PrivilegedActionException e) { + Asserts.check(e.getCause() instanceof IOException, + "method contract violation only checked exceptions are wrapped: " + e.getCause()); + // only checked exceptions are wrapped - error and RTExceptions are rethrown by doPrivileged + throw (IOException) e.getCause(); + } + } + @Override public Socket createLayeredSocket( final Socket socket,