diff --git a/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java b/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java index 88fa504e269..a930c5d9f21 100644 --- a/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java +++ b/jetty-client/src/test/java/org/eclipse/jetty/client/HttpClientTLSTest.java @@ -35,7 +35,9 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSocket; @@ -63,6 +65,7 @@ import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.ExecutorThreadPool; import org.eclipse.jetty.util.thread.QueuedThreadPool; +import org.hamcrest.Matchers; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledOnJre; @@ -772,4 +775,52 @@ public class HttpClientTLSTest assertEquals(0, serverBytes.get()); assertEquals(0, clientBytes.get()); } + + @Test + public void testSSLEngineClosedDuringHandshake() throws Exception + { + SslContextFactory.Server serverTLSFactory = createServerSslContextFactory(); + startServer(serverTLSFactory, new EmptyServerHandler()); + + SslContextFactory.Client clientTLSFactory = createClientSslContextFactory(); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSelectors(1); + clientConnector.setSslContextFactory(clientTLSFactory); + QueuedThreadPool clientThreads = new QueuedThreadPool(); + clientThreads.setName("client"); + clientConnector.setExecutor(clientThreads); + client = new HttpClient(new HttpClientTransportOverHTTP(clientConnector)) + { + @Override + protected ClientConnectionFactory newSslClientConnectionFactory(SslContextFactory.Client sslContextFactory, ClientConnectionFactory connectionFactory) + { + if (sslContextFactory == null) + sslContextFactory = getSslContextFactory(); + return new SslClientConnectionFactory(sslContextFactory, getByteBufferPool(), getExecutor(), connectionFactory) + { + @Override + protected SslConnection newSslConnection(ByteBufferPool byteBufferPool, Executor executor, EndPoint endPoint, SSLEngine engine) + { + return new SslConnection(byteBufferPool, executor, endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption()) + { + @Override + protected SSLEngineResult wrap(SSLEngine sslEngine, ByteBuffer[] input, ByteBuffer output) throws SSLException + { + sslEngine.closeOutbound(); + return super.wrap(sslEngine, input, output); + } + }; + } + }; + } + }; + client.setExecutor(clientThreads); + client.start(); + + ExecutionException failure = assertThrows(ExecutionException.class, () -> client.newRequest("localhost", connector.getLocalPort()) + .scheme(HttpScheme.HTTPS.asString()) + .send()); + Throwable cause = failure.getCause(); + assertThat(cause, Matchers.instanceOf(SSLHandshakeException.class)); + } } diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java b/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java index 3f8b445fe07..745b286b8f1 100644 --- a/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java +++ b/jetty-io/src/main/java/org/eclipse/jetty/io/ssl/SslConnection.java @@ -358,6 +358,16 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr _decryptedEndPoint.onFillableFail(cause == null ? new IOException() : cause); } + protected SSLEngineResult wrap(SSLEngine sslEngine, ByteBuffer[] input, ByteBuffer output) throws SSLException + { + return sslEngine.wrap(input, output); + } + + protected SSLEngineResult unwrap(SSLEngine sslEngine, ByteBuffer input, ByteBuffer output) throws SSLException + { + return sslEngine.unwrap(input, output); + } + @Override public String toConnectionString() { @@ -621,7 +631,7 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr try { _underflown = false; - unwrapResult = _sslEngine.unwrap(_encryptedInput, appIn); + unwrapResult = unwrap(_sslEngine, _encryptedInput, appIn); } finally { @@ -696,8 +706,8 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr } catch (Throwable x) { - Throwable failure = handleException(x, "fill"); - handshakeFailed(failure); + Throwable f = handleException(x, "fill"); + Throwable failure = handshakeFailed(f); if (_flushState == FlushState.WAIT_FOR_FILL) { _flushState = FlushState.IDLE; @@ -834,7 +844,7 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr } } - private void handshakeFailed(Throwable failure) + private Throwable handshakeFailed(Throwable failure) { if (_handshake.compareAndSet(HandshakeState.HANDSHAKE, HandshakeState.FAILED)) { @@ -844,6 +854,7 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr failure = new SSLHandshakeException(failure.getMessage()).initCause(failure); notifyHandshakeFailed(_sslEngine, failure); } + return failure; } private void terminateInput() @@ -958,7 +969,7 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr SSLEngineResult wrapResult; try { - wrapResult = _sslEngine.wrap(appOuts, _encryptedOutput); + wrapResult = wrap(_sslEngine, appOuts, _encryptedOutput); } finally { @@ -1037,8 +1048,7 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr catch (Throwable x) { Throwable failure = handleException(x, "flush"); - handshakeFailed(failure); - throw failure; + throw handshakeFailed(failure); } finally {