Fixes #3311 - Ability to serve HTTP and HTTPS from the same port.

Updated implementation and tests after reviews.

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Simone Bordet 2019-02-03 14:24:07 +01:00
parent d9855fb1bc
commit da490673af
2 changed files with 146 additions and 83 deletions

View File

@ -20,47 +20,51 @@ package org.eclipse.jetty.server;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.io.AbstractConnection; import org.eclipse.jetty.io.AbstractConnection;
import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint; import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
/** /**
* <p>A ConnectionFactory whose connections detect whether the first bytes are * <p>A ConnectionFactory whose connections detect whether the first bytes are
* TLS bytes and upgrades to either a TLS connection or to a plain connection.</p> * TLS bytes and upgrades to either a TLS connection or to another configurable
* connection.</p>
*/ */
public class PlainOrSslConnectionFactory extends AbstractConnectionFactory public class OptionalSslConnectionFactory extends AbstractConnectionFactory
{ {
private static final Logger LOG = Log.getLogger(PlainOrSslConnection.class); private static final Logger LOG = Log.getLogger(OptionalSslConnection.class);
private static final int TLS_ALERT_FRAME_TYPE = 0x15; private static final int TLS_ALERT_FRAME_TYPE = 0x15;
private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16; private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16;
private static final int TLS_MAJOR_VERSION = 3;
private final SslConnectionFactory sslConnectionFactory; private final SslConnectionFactory sslConnectionFactory;
private final String plainProtocol; private final String otherProtocol;
/** /**
* <p>Creates a new plain or TLS ConnectionFactory.</p> * <p>Creates a new ConnectionFactory whose connections can upgrade to TLS or another protocol.</p>
* <p>If {@code plainProtocol} is {@code null}, and the first bytes are not TLS, then * <p>If {@code otherProtocol} is {@code null}, and the first bytes are not TLS, then
* {@link #unknownProtocol(ByteBuffer, EndPoint)} is called; applications may override its * {@link #otherProtocol(ByteBuffer, EndPoint)} is called.</p>
* behavior (by default it closes the EndPoint) for example by writing a minimal response. </p>
* *
* @param sslConnectionFactory The SslConnectionFactory to use if the first bytes are TLS * @param sslConnectionFactory The SslConnectionFactory to use if the first bytes are TLS
* @param plainProtocol the protocol of the ConnectionFactory to use if the first bytes are not TLS, or null. * @param otherProtocol the protocol of the ConnectionFactory to use if the first bytes are not TLS,
* or null to explicitly handle the non-TLS case
*/ */
public PlainOrSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String plainProtocol) public OptionalSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String otherProtocol)
{ {
super("plain|ssl"); super("ssl|other");
this.sslConnectionFactory = sslConnectionFactory; this.sslConnectionFactory = sslConnectionFactory;
this.plainProtocol = plainProtocol; this.otherProtocol = otherProtocol;
} }
@Override @Override
public Connection newConnection(Connector connector, EndPoint endPoint) public Connection newConnection(Connector connector, EndPoint endPoint)
{ {
return configure(new PlainOrSslConnection(endPoint, connector), connector, endPoint); return configure(new OptionalSslConnection(endPoint, connector), connector, endPoint);
} }
/** /**
@ -70,39 +74,61 @@ public class PlainOrSslConnectionFactory extends AbstractConnectionFactory
protected boolean seemsTLS(ByteBuffer buffer) protected boolean seemsTLS(ByteBuffer buffer)
{ {
int tlsFrameType = buffer.get(0) & 0xFF; int tlsFrameType = buffer.get(0) & 0xFF;
return tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE; int tlsMajorVersion = buffer.get(1) & 0xFF;
return (tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE) && tlsMajorVersion == TLS_MAJOR_VERSION;
} }
/** /**
* <p>Callback method invoked when {@code plainProtocol} is {@code null} * <p>Callback method invoked when {@code otherProtocol} is {@code null}
* and the first bytes are not TLS.</p> * and the first bytes are not TLS.</p>
* <p>This typically happens when a client is trying to connect to a TLS * <p>This typically happens when a client is trying to connect to a TLS
* port using the {@code http} scheme (and not the {@code https} scheme).</p> * port using the {@code http} scheme (and not the {@code https} scheme).</p>
* <p>This method may be overridden to write back a minimal response such as:</p>
* <pre>
* HTTP/1.1 400 Bad Request
* Content-Length: 35
* Content-Type: text/plain; charset=UTF8
* Connection: close
*
* Plain HTTP request sent to TLS port
* </pre>
* *
* @param buffer The buffer with the first bytes of the connection * @param buffer The buffer with the first bytes of the connection
* @param endPoint The connection EndPoint object * @param endPoint The connection EndPoint object
* @see #seemsTLS(ByteBuffer) * @see #seemsTLS(ByteBuffer)
*/ */
protected void unknownProtocol(ByteBuffer buffer, EndPoint endPoint) protected void otherProtocol(ByteBuffer buffer, EndPoint endPoint)
{
// There are always at least 2 bytes.
int byte1 = buffer.get(0) & 0xFF;
int byte2 = buffer.get(1) & 0xFF;
if (byte1 == 'G' && byte2 == 'E')
{
// Plain text HTTP to a HTTPS port,
// write a minimal response.
String body = "" +
"<!DOCTYPE html>\r\n" +
"<html>\r\n" +
"<head><title>Bad Request</title></head>\r\n" +
"<body>" +
"<h1>Bad Request</h1>" +
"<p>HTTP request to HTTPS port</p>" +
"</body>\r\n" +
"</html>";
String response = "" +
"HTTP/1.1 400 Bad Request\r\n" +
"Content-Type: text/html\r\n" +
"Content-Length: " + body.length() + "\r\n" +
"Connection: close\r\n" +
"\r\n" +
body;
Callback.Completable completable = new Callback.Completable();
endPoint.write(completable, ByteBuffer.wrap(response.getBytes(StandardCharsets.US_ASCII)));
completable.whenComplete((r, x) -> endPoint.close());
}
else
{ {
endPoint.close(); endPoint.close();
} }
}
private class PlainOrSslConnection extends AbstractConnection implements Connection.UpgradeFrom private class OptionalSslConnection extends AbstractConnection implements Connection.UpgradeFrom
{ {
private final Connector connector; private final Connector connector;
private final ByteBuffer buffer; private final ByteBuffer buffer;
public PlainOrSslConnection(EndPoint endPoint, Connector connector) public OptionalSslConnection(EndPoint endPoint, Connector connector)
{ {
super(endPoint, connector.getExecutor()); super(endPoint, connector.getExecutor());
this.connector = connector; this.connector = connector;
@ -120,19 +146,29 @@ public class PlainOrSslConnectionFactory extends AbstractConnectionFactory
public void onFillable() public void onFillable()
{ {
try try
{
while (true)
{ {
int filled = getEndPoint().fill(buffer); int filled = getEndPoint().fill(buffer);
if (filled > 0) if (filled > 0)
{
// Always have at least 2 bytes.
if (BufferUtil.length(buffer) >= 2)
{ {
upgrade(buffer); upgrade(buffer);
break;
}
} }
else if (filled == 0) else if (filled == 0)
{ {
fillInterested(); fillInterested();
break;
} }
else else
{ {
close(); close();
break;
}
} }
} }
catch (IOException x) catch (IOException x)
@ -162,27 +198,27 @@ public class PlainOrSslConnectionFactory extends AbstractConnectionFactory
} }
else else
{ {
if (plainProtocol != null) if (otherProtocol != null)
{ {
ConnectionFactory connectionFactory = connector.getConnectionFactory(plainProtocol); ConnectionFactory connectionFactory = connector.getConnectionFactory(otherProtocol);
if (connectionFactory != null) if (connectionFactory != null)
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Detected plain bytes, upgrading to {}", connectionFactory); LOG.debug("Detected non-TLS bytes, upgrading to {}", connectionFactory);
Connection next = connectionFactory.newConnection(connector, endPoint); Connection next = connectionFactory.newConnection(connector, endPoint);
endPoint.upgrade(next); endPoint.upgrade(next);
} }
else else
{ {
LOG.warn("Missing {} {} in {}", plainProtocol, ConnectionFactory.class.getSimpleName(), connector); LOG.warn("Missing {} {} in {}", otherProtocol, ConnectionFactory.class.getSimpleName(), connector);
close(); close();
} }
} }
else else
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Detected plain bytes, but no configured protocol to upgrade to"); LOG.debug("Detected non-TLS bytes, but no other protocol to upgrade to");
unknownProtocol(buffer, endPoint); otherProtocol(buffer, endPoint);
} }
} }
} }

View File

@ -22,7 +22,6 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket; import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.function.Function; import java.util.function.Function;
@ -31,10 +30,8 @@ import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpTester; import org.eclipse.jetty.http.HttpTester;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.server.handler.AbstractHandler; import org.eclipse.jetty.server.handler.AbstractHandler;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils; import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool; import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
@ -43,12 +40,12 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
public class PlainOrSslConnectionTest public class OptionalSslConnectionTest
{ {
private Server server; private Server server;
private ServerConnector connector; private ServerConnector connector;
private void startServer(Function<SslConnectionFactory, PlainOrSslConnectionFactory> configFn, Handler handler) throws Exception private void startServer(Function<SslConnectionFactory, OptionalSslConnectionFactory> configFn, Handler handler) throws Exception
{ {
QueuedThreadPool serverThreads = new QueuedThreadPool(); QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server"); serverThreads.setName("server");
@ -63,8 +60,8 @@ public class PlainOrSslConnectionTest
HttpConfiguration httpConfig = new HttpConfiguration(); HttpConfiguration httpConfig = new HttpConfiguration();
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
PlainOrSslConnectionFactory plainOrSsl = configFn.apply(ssl); OptionalSslConnectionFactory sslOrOther = configFn.apply(ssl);
connector = new ServerConnector(server, 1, 1, plainOrSsl, ssl, http); connector = new ServerConnector(server, 1, 1, sslOrOther, ssl, http);
server.addConnector(connector); server.addConnector(connector);
server.setHandler(handler); server.setHandler(handler);
@ -79,34 +76,20 @@ public class PlainOrSslConnectionTest
server.stop(); server.stop();
} }
private PlainOrSslConnectionFactory plainOrSsl(SslConnectionFactory ssl) private OptionalSslConnectionFactory optionalSsl(SslConnectionFactory ssl)
{ {
return new PlainOrSslConnectionFactory(ssl, ssl.getNextProtocol()); return new OptionalSslConnectionFactory(ssl, ssl.getNextProtocol());
} }
private PlainOrSslConnectionFactory plainToSslWithReport(SslConnectionFactory ssl) private OptionalSslConnectionFactory optionalSslNoOtherProtocol(SslConnectionFactory ssl)
{ {
return new PlainOrSslConnectionFactory(ssl, null) return new OptionalSslConnectionFactory(ssl, null);
{
@Override
protected void unknownProtocol(ByteBuffer buffer, EndPoint endPoint)
{
String response = "" +
"HTTP/1.1 400 Bad Request\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n";
Callback.Completable callback = new Callback.Completable();
endPoint.write(callback, ByteBuffer.wrap(response.getBytes(StandardCharsets.US_ASCII)));
callback.whenComplete((r, x) -> endPoint.close());
}
};
} }
@Test @Test
public void testPlainOrSslConnection() throws Exception public void testOptionalSslConnection() throws Exception
{ {
startServer(this::plainOrSsl, new EmptyServerHandler()); startServer(this::optionalSsl, new EmptyServerHandler());
String request = "" + String request = "" +
"GET / HTTP/1.1\r\n" + "GET / HTTP/1.1\r\n" +
@ -152,9 +135,52 @@ public class PlainOrSslConnectionTest
} }
@Test @Test
public void testPlainToSslWithReport() throws Exception public void testOptionalSslConnectionWithOnlyOneByteShouldIdleTimeout() throws Exception
{ {
startServer(this::plainToSslWithReport, new EmptyServerHandler()); startServer(this::optionalSsl, new EmptyServerHandler());
long idleTimeout = 1000;
connector.setIdleTimeout(idleTimeout);
try (Socket socket = new Socket())
{
socket.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000);
OutputStream output = socket.getOutputStream();
output.write(0x16);
output.flush();
socket.setSoTimeout((int)(2 * idleTimeout));
InputStream input = socket.getInputStream();
int read = input.read();
assertEquals(-1, read);
}
}
@Test
public void testOptionalSslConnectionWithUnknownBytes() throws Exception
{
startServer(this::optionalSslNoOtherProtocol, new EmptyServerHandler());
try (Socket socket = new Socket())
{
socket.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000);
OutputStream output = socket.getOutputStream();
output.write(0x00);
output.flush();
Thread.sleep(500);
output.write(0x00);
output.flush();
socket.setSoTimeout(5000);
InputStream input = socket.getInputStream();
int read = input.read();
assertEquals(-1, read);
}
}
@Test
public void testOptionalSslConnectionWithHTTPBytes() throws Exception
{
startServer(this::optionalSslNoOtherProtocol, new EmptyServerHandler());
String request = "" + String request = "" +
"GET / HTTP/1.1\r\n" + "GET / HTTP/1.1\r\n" +
@ -162,17 +188,18 @@ public class PlainOrSslConnectionTest
"\r\n"; "\r\n";
byte[] requestBytes = request.getBytes(StandardCharsets.US_ASCII); byte[] requestBytes = request.getBytes(StandardCharsets.US_ASCII);
// Send a plain text HTTP request to SSL port: we should get back a minimal HTTP response. // Send a plain text HTTP request to SSL port,
try (Socket plain = new Socket()) // we should get back a minimal HTTP response.
try (Socket socket = new Socket())
{ {
plain.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000); socket.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000);
OutputStream plainOutput = plain.getOutputStream(); OutputStream output = socket.getOutputStream();
plainOutput.write(requestBytes); output.write(requestBytes);
plainOutput.flush(); output.flush();
plain.setSoTimeout(5000); socket.setSoTimeout(5000);
InputStream plainInput = plain.getInputStream(); InputStream input = socket.getInputStream();
HttpTester.Response response = HttpTester.parseResponse(plainInput); HttpTester.Response response = HttpTester.parseResponse(input);
assertNotNull(response); assertNotNull(response);
assertEquals(HttpStatus.BAD_REQUEST_400, response.getStatus()); assertEquals(HttpStatus.BAD_REQUEST_400, response.getStatus());
} }