Issue #4217 - SslConnection.DecryptedEndpoint.flush eternal busy loop.

Updates after review.
Added test case.

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Simone Bordet 2019-10-19 20:06:10 +02:00
parent 991cf20cce
commit 73eb82c20f
2 changed files with 126 additions and 10 deletions

View File

@ -807,4 +807,120 @@ public class HttpClientTLSTest
Throwable cause = failure.getCause(); Throwable cause = failure.getCause();
assertThat(cause, Matchers.instanceOf(SSLHandshakeException.class)); assertThat(cause, Matchers.instanceOf(SSLHandshakeException.class));
} }
@Test
public void testTLSLargeFragments() throws Exception
{
CountDownLatch serverLatch = new CountDownLatch(1);
SslContextFactory serverTLSFactory = createServerSslContextFactory();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer(new SecureRequestCustomizer());
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(serverTLSFactory, http.getProtocol())
{
@Override
protected SslConnection newSslConnection(Connector connector, EndPoint endPoint, SSLEngine engine)
{
return new SslConnection(connector.getByteBufferPool(), connector.getExecutor(), endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected SSLEngineResult unwrap(SSLEngine sslEngine, ByteBuffer input, ByteBuffer output) throws SSLException
{
int inputBytes = input.remaining();
SSLEngineResult result = super.unwrap(sslEngine, input, output);
if (inputBytes == 5)
serverLatch.countDown();
return result;
}
};
}
};
connector = new ServerConnector(server, 1, 1, ssl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler());
server.start();
long idleTimeout = 2000;
CountDownLatch clientLatch = new CountDownLatch(1);
SslContextFactory clientTLSFactory = createClientSslContextFactory();
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
client = new HttpClient(clientTLSFactory)
{
@Override
protected ClientConnectionFactory newSslClientConnectionFactory(SslContextFactory sslContextFactory, ClientConnectionFactory connectionFactory)
{
if (sslContextFactory == null)
sslContextFactory = getSslContextFactory();
return new SslClientConnectionFactory(sslContextFactory, getByteBufferPool(), getExecutor(), connectionFactory)
{
@Override
protected SslConnection newSslConnection(ByteBufferPool byteBufferPool, Executor executor, EndPoint endPoint, SSLEngine engine)
{
return new SslConnection(byteBufferPool, executor, endPoint, engine, isDirectBuffersForEncryption(), isDirectBuffersForDecryption())
{
@Override
protected SSLEngineResult wrap(SSLEngine sslEngine, ByteBuffer[] input, ByteBuffer output) throws SSLException
{
try
{
clientLatch.countDown();
assertTrue(serverLatch.await(5, TimeUnit.SECONDS));
return super.wrap(sslEngine, input, output);
}
catch (InterruptedException x)
{
throw new SSLException(x);
}
}
};
}
};
}
};
client.setIdleTimeout(idleTimeout);
client.setExecutor(clientThreads);
client.start();
String host = "localhost";
int port = connector.getLocalPort();
CountDownLatch responseLatch = new CountDownLatch(1);
client.newRequest(host, port)
.scheme(HttpScheme.HTTPS.asString())
.send(result ->
{
assertTrue(result.isSucceeded());
assertEquals(HttpStatus.OK_200, result.getResponse().getStatus());
responseLatch.countDown();
});
// Wait for the TLS buffers to be acquired by the client, then the
// HTTP request will be paused waiting for the TLS buffer to be expanded.
assertTrue(clientLatch.await(5, TimeUnit.SECONDS));
// Send the large frame bytes that will enlarge the TLS buffers.
try (Socket socket = new Socket(host, port))
{
OutputStream output = socket.getOutputStream();
byte[] largeFrameBytes = new byte[5];
largeFrameBytes[0] = 22; // Type = handshake
largeFrameBytes[1] = 3; // Major TLS version
largeFrameBytes[2] = 3; // Minor TLS version
// Frame length is 0x7FFF == 32767, i.e. a "large fragment".
// Maximum allowed by RFC 8446 is 16384, but SSLEngine supports up to 33093.
largeFrameBytes[3] = 0x7F; // Length hi byte
largeFrameBytes[4] = (byte)0xFF; // Length lo byte
output.write(largeFrameBytes);
output.flush();
// Just close the connection now, the large frame
// length was enough to trigger the buffer expansion.
}
// The HTTP request will resume and be forced to handle the TLS buffer expansion.
assertTrue(responseLatch.await(5, TimeUnit.SECONDS));
}
} }

View File

@ -25,6 +25,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.ToIntFunction;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.HandshakeStatus;
@ -311,23 +312,22 @@ public class SslConnection extends AbstractConnection implements Connection.Upgr
private int getApplicationBufferSize() private int getApplicationBufferSize()
{ {
SSLSession hsSession = _sslEngine.getHandshakeSession(); return getBufferSize(SSLSession::getApplicationBufferSize);
SSLSession session = _sslEngine.getSession();
int size = session.getApplicationBufferSize();
if (hsSession == null)
return size;
int hsSize = hsSession.getApplicationBufferSize();
return Math.max(hsSize, size);
} }
private int getPacketBufferSize() private int getPacketBufferSize()
{
return getBufferSize(SSLSession::getPacketBufferSize);
}
private int getBufferSize(ToIntFunction<SSLSession> bufferSizeFn)
{ {
SSLSession hsSession = _sslEngine.getHandshakeSession(); SSLSession hsSession = _sslEngine.getHandshakeSession();
SSLSession session = _sslEngine.getSession(); SSLSession session = _sslEngine.getSession();
int size = session.getPacketBufferSize(); int size = bufferSizeFn.applyAsInt(session);
if (hsSession == null) if (hsSession == null || hsSession == session)
return size; return size;
int hsSize = hsSession.getPacketBufferSize(); int hsSize = bufferSizeFn.applyAsInt(hsSession);
return Math.max(hsSize, size); return Math.max(hsSize, size);
} }