diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/PlainOrSslConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java similarity index 52% rename from jetty-server/src/main/java/org/eclipse/jetty/server/PlainOrSslConnectionFactory.java rename to jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java index d3d525f5924..7b75a01917a 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/PlainOrSslConnectionFactory.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java @@ -20,47 +20,51 @@ package org.eclipse.jetty.server; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import org.eclipse.jetty.io.AbstractConnection; import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.EndPoint; import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; /** *

A ConnectionFactory whose connections detect whether the first bytes are - * TLS bytes and upgrades to either a TLS connection or to a plain connection.

+ * TLS bytes and upgrades to either a TLS connection or to another configurable + * connection.

*/ -public class PlainOrSslConnectionFactory extends AbstractConnectionFactory +public class OptionalSslConnectionFactory extends AbstractConnectionFactory { - private static final Logger LOG = Log.getLogger(PlainOrSslConnection.class); + private static final Logger LOG = Log.getLogger(OptionalSslConnection.class); private static final int TLS_ALERT_FRAME_TYPE = 0x15; private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16; + private static final int TLS_MAJOR_VERSION = 3; private final SslConnectionFactory sslConnectionFactory; - private final String plainProtocol; + private final String otherProtocol; /** - *

Creates a new plain or TLS ConnectionFactory.

- *

If {@code plainProtocol} is {@code null}, and the first bytes are not TLS, then - * {@link #unknownProtocol(ByteBuffer, EndPoint)} is called; applications may override its - * behavior (by default it closes the EndPoint) for example by writing a minimal response.

+ *

Creates a new ConnectionFactory whose connections can upgrade to TLS or another protocol.

+ *

If {@code otherProtocol} is {@code null}, and the first bytes are not TLS, then + * {@link #otherProtocol(ByteBuffer, EndPoint)} is called.

* * @param sslConnectionFactory The SslConnectionFactory to use if the first bytes are TLS - * @param plainProtocol the protocol of the ConnectionFactory to use if the first bytes are not TLS, or null. + * @param otherProtocol the protocol of the ConnectionFactory to use if the first bytes are not TLS, + * or null to explicitly handle the non-TLS case */ - public PlainOrSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String plainProtocol) + public OptionalSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String otherProtocol) { - super("plain|ssl"); + super("ssl|other"); this.sslConnectionFactory = sslConnectionFactory; - this.plainProtocol = plainProtocol; + this.otherProtocol = otherProtocol; } @Override public Connection newConnection(Connector connector, EndPoint endPoint) { - return configure(new PlainOrSslConnection(endPoint, connector), connector, endPoint); + return configure(new OptionalSslConnection(endPoint, connector), connector, endPoint); } /** @@ -70,39 +74,61 @@ public class PlainOrSslConnectionFactory extends AbstractConnectionFactory protected boolean seemsTLS(ByteBuffer buffer) { int tlsFrameType = buffer.get(0) & 0xFF; - return tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE; + int tlsMajorVersion = buffer.get(1) & 0xFF; + return (tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE) && tlsMajorVersion == TLS_MAJOR_VERSION; } /** - *

Callback method invoked when {@code plainProtocol} is {@code null} + *

Callback method invoked when {@code otherProtocol} is {@code null} * and the first bytes are not TLS.

*

This typically happens when a client is trying to connect to a TLS * port using the {@code http} scheme (and not the {@code https} scheme).

- *

This method may be overridden to write back a minimal response such as:

- *
-     * HTTP/1.1 400 Bad Request
-     * Content-Length: 35
-     * Content-Type: text/plain; charset=UTF8
-     * Connection: close
-     *
-     * Plain HTTP request sent to TLS port
-     * 
* * @param buffer The buffer with the first bytes of the connection * @param endPoint The connection EndPoint object * @see #seemsTLS(ByteBuffer) */ - protected void unknownProtocol(ByteBuffer buffer, EndPoint endPoint) + protected void otherProtocol(ByteBuffer buffer, EndPoint endPoint) { - endPoint.close(); + // There are always at least 2 bytes. + int byte1 = buffer.get(0) & 0xFF; + int byte2 = buffer.get(1) & 0xFF; + if (byte1 == 'G' && byte2 == 'E') + { + // Plain text HTTP to a HTTPS port, + // write a minimal response. + String body = "" + + "\r\n" + + "\r\n" + + "Bad Request\r\n" + + "" + + "

Bad Request

" + + "

HTTP request to HTTPS port

" + + "\r\n" + + ""; + String response = "" + + "HTTP/1.1 400 Bad Request\r\n" + + "Content-Type: text/html\r\n" + + "Content-Length: " + body.length() + "\r\n" + + "Connection: close\r\n" + + "\r\n" + + body; + Callback.Completable completable = new Callback.Completable(); + endPoint.write(completable, ByteBuffer.wrap(response.getBytes(StandardCharsets.US_ASCII))); + completable.whenComplete((r, x) -> endPoint.close()); + } + else + { + endPoint.close(); + } } - private class PlainOrSslConnection extends AbstractConnection implements Connection.UpgradeFrom + private class OptionalSslConnection extends AbstractConnection implements Connection.UpgradeFrom { private final Connector connector; private final ByteBuffer buffer; - public PlainOrSslConnection(EndPoint endPoint, Connector connector) + public OptionalSslConnection(EndPoint endPoint, Connector connector) { super(endPoint, connector.getExecutor()); this.connector = connector; @@ -121,18 +147,28 @@ public class PlainOrSslConnectionFactory extends AbstractConnectionFactory { try { - int filled = getEndPoint().fill(buffer); - if (filled > 0) + while (true) { - upgrade(buffer); - } - else if (filled == 0) - { - fillInterested(); - } - else - { - close(); + int filled = getEndPoint().fill(buffer); + if (filled > 0) + { + // Always have at least 2 bytes. + if (BufferUtil.length(buffer) >= 2) + { + upgrade(buffer); + break; + } + } + else if (filled == 0) + { + fillInterested(); + break; + } + else + { + close(); + break; + } } } catch (IOException x) @@ -162,27 +198,27 @@ public class PlainOrSslConnectionFactory extends AbstractConnectionFactory } else { - if (plainProtocol != null) + if (otherProtocol != null) { - ConnectionFactory connectionFactory = connector.getConnectionFactory(plainProtocol); + ConnectionFactory connectionFactory = connector.getConnectionFactory(otherProtocol); if (connectionFactory != null) { if (LOG.isDebugEnabled()) - LOG.debug("Detected plain bytes, upgrading to {}", connectionFactory); + LOG.debug("Detected non-TLS bytes, upgrading to {}", connectionFactory); Connection next = connectionFactory.newConnection(connector, endPoint); endPoint.upgrade(next); } else { - LOG.warn("Missing {} {} in {}", plainProtocol, ConnectionFactory.class.getSimpleName(), connector); + LOG.warn("Missing {} {} in {}", otherProtocol, ConnectionFactory.class.getSimpleName(), connector); close(); } } else { if (LOG.isDebugEnabled()) - LOG.debug("Detected plain bytes, but no configured protocol to upgrade to"); - unknownProtocol(buffer, endPoint); + LOG.debug("Detected non-TLS bytes, but no other protocol to upgrade to"); + otherProtocol(buffer, endPoint); } } } diff --git a/jetty-server/src/test/java/org/eclipse/jetty/server/PlainOrSslConnectionTest.java b/jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java similarity index 64% rename from jetty-server/src/test/java/org/eclipse/jetty/server/PlainOrSslConnectionTest.java rename to jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java index 6267190d014..7a8e7a00e7e 100644 --- a/jetty-server/src/test/java/org/eclipse/jetty/server/PlainOrSslConnectionTest.java +++ b/jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java @@ -22,7 +22,6 @@ import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; import java.net.Socket; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.function.Function; @@ -31,10 +30,8 @@ import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.http.HttpTester; -import org.eclipse.jetty.io.EndPoint; import org.eclipse.jetty.server.handler.AbstractHandler; import org.eclipse.jetty.toolchain.test.MavenTestingUtils; -import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.QueuedThreadPool; import org.junit.jupiter.api.AfterEach; @@ -43,12 +40,12 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -public class PlainOrSslConnectionTest +public class OptionalSslConnectionTest { private Server server; private ServerConnector connector; - private void startServer(Function configFn, Handler handler) throws Exception + private void startServer(Function configFn, Handler handler) throws Exception { QueuedThreadPool serverThreads = new QueuedThreadPool(); serverThreads.setName("server"); @@ -63,8 +60,8 @@ public class PlainOrSslConnectionTest HttpConfiguration httpConfig = new HttpConfiguration(); HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); - PlainOrSslConnectionFactory plainOrSsl = configFn.apply(ssl); - connector = new ServerConnector(server, 1, 1, plainOrSsl, ssl, http); + OptionalSslConnectionFactory sslOrOther = configFn.apply(ssl); + connector = new ServerConnector(server, 1, 1, sslOrOther, ssl, http); server.addConnector(connector); server.setHandler(handler); @@ -79,34 +76,20 @@ public class PlainOrSslConnectionTest server.stop(); } - private PlainOrSslConnectionFactory plainOrSsl(SslConnectionFactory ssl) + private OptionalSslConnectionFactory optionalSsl(SslConnectionFactory ssl) { - return new PlainOrSslConnectionFactory(ssl, ssl.getNextProtocol()); + return new OptionalSslConnectionFactory(ssl, ssl.getNextProtocol()); } - private PlainOrSslConnectionFactory plainToSslWithReport(SslConnectionFactory ssl) + private OptionalSslConnectionFactory optionalSslNoOtherProtocol(SslConnectionFactory ssl) { - return new PlainOrSslConnectionFactory(ssl, null) - { - @Override - protected void unknownProtocol(ByteBuffer buffer, EndPoint endPoint) - { - String response = "" + - "HTTP/1.1 400 Bad Request\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n"; - Callback.Completable callback = new Callback.Completable(); - endPoint.write(callback, ByteBuffer.wrap(response.getBytes(StandardCharsets.US_ASCII))); - callback.whenComplete((r, x) -> endPoint.close()); - } - }; + return new OptionalSslConnectionFactory(ssl, null); } @Test - public void testPlainOrSslConnection() throws Exception + public void testOptionalSslConnection() throws Exception { - startServer(this::plainOrSsl, new EmptyServerHandler()); + startServer(this::optionalSsl, new EmptyServerHandler()); String request = "" + "GET / HTTP/1.1\r\n" + @@ -152,9 +135,52 @@ public class PlainOrSslConnectionTest } @Test - public void testPlainToSslWithReport() throws Exception + public void testOptionalSslConnectionWithOnlyOneByteShouldIdleTimeout() throws Exception { - startServer(this::plainToSslWithReport, new EmptyServerHandler()); + startServer(this::optionalSsl, new EmptyServerHandler()); + long idleTimeout = 1000; + connector.setIdleTimeout(idleTimeout); + + try (Socket socket = new Socket()) + { + socket.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000); + OutputStream output = socket.getOutputStream(); + output.write(0x16); + output.flush(); + + socket.setSoTimeout((int)(2 * idleTimeout)); + InputStream input = socket.getInputStream(); + int read = input.read(); + assertEquals(-1, read); + } + } + + @Test + public void testOptionalSslConnectionWithUnknownBytes() throws Exception + { + startServer(this::optionalSslNoOtherProtocol, new EmptyServerHandler()); + + try (Socket socket = new Socket()) + { + socket.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000); + OutputStream output = socket.getOutputStream(); + output.write(0x00); + output.flush(); + Thread.sleep(500); + output.write(0x00); + output.flush(); + + socket.setSoTimeout(5000); + InputStream input = socket.getInputStream(); + int read = input.read(); + assertEquals(-1, read); + } + } + + @Test + public void testOptionalSslConnectionWithHTTPBytes() throws Exception + { + startServer(this::optionalSslNoOtherProtocol, new EmptyServerHandler()); String request = "" + "GET / HTTP/1.1\r\n" + @@ -162,17 +188,18 @@ public class PlainOrSslConnectionTest "\r\n"; byte[] requestBytes = request.getBytes(StandardCharsets.US_ASCII); - // Send a plain text HTTP request to SSL port: we should get back a minimal HTTP response. - try (Socket plain = new Socket()) + // Send a plain text HTTP request to SSL port, + // we should get back a minimal HTTP response. + try (Socket socket = new Socket()) { - plain.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000); - OutputStream plainOutput = plain.getOutputStream(); - plainOutput.write(requestBytes); - plainOutput.flush(); + socket.connect(new InetSocketAddress("localhost", connector.getLocalPort()), 1000); + OutputStream output = socket.getOutputStream(); + output.write(requestBytes); + output.flush(); - plain.setSoTimeout(5000); - InputStream plainInput = plain.getInputStream(); - HttpTester.Response response = HttpTester.parseResponse(plainInput); + socket.setSoTimeout(5000); + InputStream input = socket.getInputStream(); + HttpTester.Response response = HttpTester.parseResponse(input); assertNotNull(response); assertEquals(HttpStatus.BAD_REQUEST_400, response.getStatus()); }